题面:
题意:
对于n个字符串,需要求出第 i 个字符串的前缀和第 j 个后缀的最长的相等长度。
然后求 ∑len(i,j)2。
官方题解:
①、hash
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=1000100;
const int maxm=100100;
const int up=100000;
llu p[maxn];
int nt[maxn],cnt[maxn];
char str[maxn];
string s[maxm];
map<llu,int>mp;
void init(void)
{
p[0]=1;
for(int i=1;i<maxn;i++)
p[i]=p[i-1]*hp;
}
void getnt(int k)
{
int len=s[k].size();
nt[0]=-1;
for(int i=1,j=-1;i<len;i++)
{
while(j>=0&&s[k][i]!=s[k][j+1]) j=nt[j];
if(s[k][i]==s[k][j+1]) j++;
nt[i]=j;
}
}
int main(void)
{
init();
int n;
scanf("%d",&n);
for(int i=1;i<=n;i++)
{
scanf("%s",str+1);
s[i]=str+1;
llu pre=0;
int len=strlen(str+1);
for(int j=len;j>=1;j--)
{
pre+=p[len-j]*str[j];
mp[pre]++;
}
}
ll ans=0;
for(int i=1;i<=n;i++)
{
getnt(i);
llu pre=0;
for(int j=0;j<s[i].size();j++)
{
pre=pre*hp+s[i][j];
cnt[j]=mp[pre];
}
//只算最长的,需要去重
for(int j=0;j<s[i].size();j++)
{
if(nt[j]>=0)
cnt[nt[j]]-=cnt[j];
}
for(int j=0;j<s[i].size();j++)
ans=(ans+1ll*(j+1)*(j+1)%mod*cnt[j])%mod;
printf("%lld\n",ans);
return 0;
}
②、后缀自动机也能做。
我们只对于每个串的末尾位置加上1,然后在parent树上,末尾节点的fa节点都是该串的后缀。
将所有的串都插入后缀自动机之后,做一遍拓扑序统计每个节点的次数。
然后对于每一个前缀,遍历求出cnt[ i ] 即可(也需要去重)。
哈哈,假广义后缀自动机过不去。。。。
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=998244353;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=1000100;
const int maxm=100100;
const int up=100000;
const int root=1;
struct Sam
{
int last,cnt;
int n;
int nt[maxn<<1][26],fa[maxn<<1];
int len[maxn<<1],sum[maxn<<1];
int x[maxn<<1],y[maxn<<1];
int ha[maxn<<1],f[maxn<<1];
int net[maxn],res[maxn];
string s[maxm];
char str[maxn];
void getnt(int k)
{
int len=s[k].size();
net[0]=-1;
for(int i=1,j=-1;i<len;i++)
{
while(j>=0&&s[k][i]!=s[k][j+1]) j=net[j];
if(s[k][i]==s[k][j+1]) j++;
net[i]=j;
}
}
void init(void)
{
last=root;
cnt=1;
fa[1]=0;
len[1]=0;
}
void _insert(int c)
{
if(nt[last][c])
{
int p=last,q=nt[p][c];
if(len[q]==len[p]+1) last=q;
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
fa[nowq]=fa[q];
fa[q]=nowq;
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
last=nowq;
}
}
else
{
int nowp=++cnt,p=last;
len[nowp]=len[last]+1;
while(p&&!nt[p][c]) nt[p][c]=nowp,p=fa[p];
if(!p) fa[nowp]=root;
else
{
int q=nt[p][c];
if(len[q]==len[p]+1) fa[nowp]=q;
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
fa[nowq]=fa[q];
fa[nowp]=fa[q]=nowq;
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
}
}
last=nowp;
}
return ;
}
void _count(void)
{
memset(x,0,sizeof(x));
for(int i=1;i<=cnt;i++) x[len[i]]++;
for(int i=1;i<=cnt;i++) x[i]+=x[i-1];
for(int i=1;i<=cnt;i++) y[x[len[i]]--]=i;
for(int i=cnt;i>=1;i--)
sum[fa[y[i]]]+=sum[y[i]];
}
void creat(void)
{
init();
for(int i=1;i<=n;i++)
{
scanf("%s",str);
s[i]=string(str);
last=root;
int len=strlen(str);
for(int j=0;j<len;j++)
_insert(str[j]-'a');
sum[last]++;
}
_count();
return ;
}
ll get_ans(void)
{
ll ans=0;
for(int i=1;i<=n;i++)
{
getnt(i);
int p=root;
for(int j=0;j<s[i].size();j++)
{
p=nt[p][s[i][j]-'a'];
res[j]=sum[p];
}
for(int j=0;j<s[i].size();j++)
{
if(net[j]>=0)
res[net[j]]-=res[j];
}
for(int j=0;j<s[i].size();j++)
ans=(ans+1ll*(j+1)*(j+1)%mod*res[j])%mod;
}
return ans;
}
}sam;
int main(void)
{
scanf("%d",&sam.n);
sam.creat();
printf("%lld\n",sam.get_ans());
return 0;
}