题目链接

思路

观察题目中的式子,可以发现前两项是定值。所以只需要求出最后一项就行了。
然后题目就转化为了求字符串中所有后缀的\(lcp\)长度之和。
可以想到用后缀数组。在后缀数组上两个后缀的\(lcp\)长度表现为两个后缀排名之间的\(height\)的最小值。
所以现在问题就又转化为了在\(height\)数组上求所有区间最小值之和。
这个可以用单调栈做到。

代码

/*
* @Author: wxyww
* @Date:   2019-01-30 19:14:49
* @Last Modified time: 2019-01-30 20:49:38
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<ctime>
#include<bitset>
using namespace std;
typedef long long ll;
#define int ll
const int N = 500010;
ll read() {
    ll x=0,f=1;char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-') f=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        x=x*10+c-'0';
        c=getchar();
    }
    return x*f;
}
int sa[N],rk[N],height[N],c[N],x[N],y[N];
char s[N];
int m,n;
void get_sa() {
   for(int i = 1;i <= m;++i) c[i] = 0;
   for(int i = 1;i <= n;++i) ++c[x[i] = s[i]];
   for(int i = 2;i <= m;++i) c[i] += c[i - 1]; 
   for(int i = n;i >= 1;--i) sa[c[x[i]]--] = i;
   for(int k = 1;k <= n;k <<= 1) {
      int num = 0;
      for(int i = n - k + 1;i <= n; ++i) y[++num] = i;
      for(int i = 1;i <= n;++i) if(sa[i] > k) y[++num] = sa[i] - k;
      for(int i = 2;i <= m;++i) c[i] = 0;
      for(int i = 1;i <= n;++i) ++c[x[i]];
      for(int i = 1;i <= m;++i) c[i] += c[i - 1];
      for(int i = n;i >= 1;--i) sa[c[x[y[i]]]--] = y[i];
      swap(x,y);
      num = 0;
      x[sa[1]] = ++num;
      for(int i = 2;i <= n;++i) {
         if(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = num;
         else x[sa[i]] = ++num;
      }
      if(num == n) break;
      m = num;
   }
}
int h[N],q[N],tail;
void get_height() {
    for(int i = 1;i <= n;++i) rk[sa[i]] = i;
    int k =  0;
    for(int i = 1;i <= n;++i) {
        if(rk[i] == 1) continue;
        if(k) --k;
        int j = sa[rk[i] - 1];
        while(j + k <= n && i + k <= n && s[j + k] == s[i + k]) ++k;
        h[i] = height[rk[i]] = k;
    }
}
ll work() {
    int tail = 0;
    ll now = 0,ans = 0;
    for(int i = 1; i <= n;++i) {
        while(height[i] < height[q[tail]] && tail) now -= height[q[tail]] * (q[tail] - q[tail - 1]),tail--;
        q[++tail] = i;
        now += height[i] * (q[tail] - q[tail - 1]);
        ans += now;
    }
    return ans;
}
int get(int x,int y) {
    int ans = 1e9;
    int l = min(rk[x],rk[y]),r = max(rk[x],rk[y]);
    for(int i = l + 1;i <= r;++i) ans = min(ans,height[i]);
    return ans;
}
signed main() {
    scanf("%s",s + 1);
    n = strlen(s + 1);
    m = 'z';
    get_sa();
    get_height();

    ll ans = 0;
    for(int i = 1;i <= n;++i)
        ans += i * (i - 1) + i * (n - i);
    ll LC = 2ll * work();
    cout<<ans - LC;
    return 0;
}