题面:
题意:
定义字符串函数 f(S,x,y)(1<=x<=y<=n) 返回一个长度为 y−x+1的串,第 k 位是 maxi=x...x+k−1Si
设集合 A={f(f(S,x1,y1),x2−x1+1,y2−x1+1)∣1<=x1<=x2<=y2<=y1<=n}
求集合A的大小,其中 N<=100000,字符集大小<=10
题解:
我们设 f操作为将当前串的第 k 位变为 maxi=1...1+k−1Si的操作
相当于我们对于S的任一子串做一次 f,再对这些做过一次 f子串的任一子串做一次 f
我们发现对做过 f子串的任一子串再做一次 f起不到任何作用,因为这时子串的子串中,前面的字符一定要比后面的字符小。
所以转换为,我们对于S的任一子串做一次 f,对于这些做过一次 f子串的所有子串中,本质不同的子串有多少个。
我们又发现,取子串 [i,x]做 f操作后的子串一定都包含在取子串 [i,x+1]做 f操作后的子串中。
问题等价为, f(S,i,n) 这n个串本质不同的子串的个数。
可以证明如果对于S的子串 [i,n]都做 f操作后,从后往前插入字典树中,长度不会超过10n,然后我们在这棵字典树上跑一棵广义后缀自动机即可。
只需要知道每个字符后面第一个大于等于它的字符出现在哪里即可(单调栈),不用把字典树建出来。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=1000100;
const int maxm=100100;
const int up=100100;
char str[maxn];
int st[maxn],p[maxn],top=0;
struct Sam
{
int last,cnt;
int n,k;
int nt[maxn<<1][10],fa[maxn<<1];
int len[maxn<<1],sum[maxn<<1];
int x[maxn<<1],y[maxn<<1];
void init(void)
{
last=1;
cnt=1;
fa[1]=0;
len[1]=0;
}
void _insert(int c)
{
if(nt[last][c])
{
int p=last,q=nt[p][c];
if(len[q]==len[p]+1) last=q;
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
fa[nowq]=fa[q];
fa[q]=nowq;
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
last=nowq;
}
}
else
{
int nowp=++cnt,p=last;
len[nowp]=len[last]+1;
while(p&&!nt[p][c]) nt[p][c]=nowp,p=fa[p];
if(!p) fa[nowp]=1;
else
{
int q=nt[p][c];
if(len[q]==len[p]+1) fa[nowp]=q;
else
{
int nowq=++cnt;
len[nowq]=len[p]+1;
memcpy(nt[nowq],nt[q],sizeof(nt[q]));
fa[nowq]=fa[q];
fa[nowp]=fa[q]=nowq;
while(p&&nt[p][c]==q) nt[p][c]=nowq,p=fa[p];
}
}
last=nowp;
}
return ;
}
}sam;
int main(void)
{
scanf("%s",str+1);
int n=strlen(str+1);
reverse(str+1,str+n+1);
sam.init();
p[0]=1;
for(int i=1;i<=n;i++)
{
while(top>0&&str[st[top]]<str[i]) top--;
sam.last=p[st[top]];
for(int j=st[top]+1;j<=i;j++)
sam._insert(str[i]-'a');
st[++top]=i;
p[i]=sam.last;
}
ll ans=0;
for(int i=1;i<=sam.cnt;i++)
ans=ans+sam.len[i]-sam.len[sam.fa[i]];
printf("%lld\n",ans);
return 0;
}