思路
观察题目中的式子,可以发现前两项是定值。所以只需要求出最后一项就行了。
然后题目就转化为了求字符串中所有后缀的\(lcp\)长度之和。
可以想到用后缀数组。在后缀数组上两个后缀的\(lcp\)长度表现为两个后缀排名之间的\(height\)的最小值。
所以现在问题就又转化为了在\(height\)数组上求所有区间最小值之和。
这个可以用单调栈做到。
代码
/*
* @Author: wxyww
* @Date: 2019-01-30 19:14:49
* @Last Modified time: 2019-01-30 20:49:38
*/
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<cstring>
#include<cmath>
#include<ctime>
#include<bitset>
using namespace std;
typedef long long ll;
#define int ll
const int N = 500010;
ll read() {
ll x=0,f=1;char c=getchar();
while(c<'0'||c>'9') {
if(c=='-') f=-1;
c=getchar();
}
while(c>='0'&&c<='9') {
x=x*10+c-'0';
c=getchar();
}
return x*f;
}
int sa[N],rk[N],height[N],c[N],x[N],y[N];
char s[N];
int m,n;
void get_sa() {
for(int i = 1;i <= m;++i) c[i] = 0;
for(int i = 1;i <= n;++i) ++c[x[i] = s[i]];
for(int i = 2;i <= m;++i) c[i] += c[i - 1];
for(int i = n;i >= 1;--i) sa[c[x[i]]--] = i;
for(int k = 1;k <= n;k <<= 1) {
int num = 0;
for(int i = n - k + 1;i <= n; ++i) y[++num] = i;
for(int i = 1;i <= n;++i) if(sa[i] > k) y[++num] = sa[i] - k;
for(int i = 2;i <= m;++i) c[i] = 0;
for(int i = 1;i <= n;++i) ++c[x[i]];
for(int i = 1;i <= m;++i) c[i] += c[i - 1];
for(int i = n;i >= 1;--i) sa[c[x[y[i]]]--] = y[i];
swap(x,y);
num = 0;
x[sa[1]] = ++num;
for(int i = 2;i <= n;++i) {
if(y[sa[i]] == y[sa[i - 1]] && y[sa[i] + k] == y[sa[i - 1] + k]) x[sa[i]] = num;
else x[sa[i]] = ++num;
}
if(num == n) break;
m = num;
}
}
int h[N],q[N],tail;
void get_height() {
for(int i = 1;i <= n;++i) rk[sa[i]] = i;
int k = 0;
for(int i = 1;i <= n;++i) {
if(rk[i] == 1) continue;
if(k) --k;
int j = sa[rk[i] - 1];
while(j + k <= n && i + k <= n && s[j + k] == s[i + k]) ++k;
h[i] = height[rk[i]] = k;
}
}
ll work() {
int tail = 0;
ll now = 0,ans = 0;
for(int i = 1; i <= n;++i) {
while(height[i] < height[q[tail]] && tail) now -= height[q[tail]] * (q[tail] - q[tail - 1]),tail--;
q[++tail] = i;
now += height[i] * (q[tail] - q[tail - 1]);
ans += now;
}
return ans;
}
int get(int x,int y) {
int ans = 1e9;
int l = min(rk[x],rk[y]),r = max(rk[x],rk[y]);
for(int i = l + 1;i <= r;++i) ans = min(ans,height[i]);
return ans;
}
signed main() {
scanf("%s",s + 1);
n = strlen(s + 1);
m = 'z';
get_sa();
get_height();
ll ans = 0;
for(int i = 1;i <= n;++i)
ans += i * (i - 1) + i * (n - i);
ll LC = 2ll * work();
cout<<ans - LC;
return 0;
}