Megumi With String

这题我T了40次左右。。。拿着别人的AC代码双向修改,我的一直T,别人的一直A。。。甚至感觉除了变量名不一样,其他的都完全一样了,还是T
噩梦经历
最后发现是初始化函数写跪了

题意:给定一个已知串 S S S,再给出另外一个串(随机)的长度,求原串在每次尾部增加节点后另外一个串的价值(价值定义见题面吧)
思路

  1. 由于原串的操作是尾部增添字符,显然与后缀自动机一致,因此考虑建立后缀自动机
  2. 由于一个长度为 i i i S S S的子串,在另外一个串中出现的期望次数为: n + 1 i 2 6 i \frac{n+1-i}{26^i} 26in+1i,与串具体是什么无关,只与长度有关(可仔细思考或大致计算一下,这点挺巧妙的)。因此我们可以先预处理出长度为 i i i S S S的子串对答案的贡献(本来是求另外一个串的价值,现在却变成了 S S S串的子串对价值的贡献,也挺好玩的),然后将后缀自动机上所有节点的贡献都加上去即可
  3. 而每加入一个新的字符,就额外计算一下新的节点对答案的贡献即可(不用计算虚节点,可以认为虚节点是之前已经计算过的)。

题面描述

#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}

const int maxn = 4e5+10;
const int inf = 0x3f3f3f3f;
const int mod = 998244353;
const double eps = 1e-7;

int l, k, n, m;
char s[maxn];
int a[maxn];
ll f[maxn], f0[maxn][51], ff[maxn], ans;
int ch[maxn][26], len[maxn], fa[maxn];
int last=1, tot=1;

void add(int c) {
    int p=last, np=last=++tot;
    len[np]=len[p]+1;
    for(; p&&!ch[p][c]; p=fa[p]) ch[p][c]=np;
    if(!p) fa[np]=1;
    else {
        int q=ch[p][c];
        if(len[q]==len[p]+1) fa[np]=q;
        else {
            int nq=++tot;
            fa[nq]=fa[q], len[nq]=len[p]+1;
            memcpy(ch[nq],ch[q],104);
            fa[q]=fa[np]=nq;
            for(; p&&ch[p][c]==q; p=fa[p]) ch[p][c]=nq;
        }
    }
    if(len[fa[np]]<n) ans=(ans+(f[min(n,len[np])]-f[len[fa[np]]]+mod)%mod)%mod;
}

void init() {
    for(int i=1; i<=tot; ++i) {
        len[i]=fa[i]=0;
        memset(ch[i],0,104);
    }
    last=tot=1; ans=0;
}

int main() {
    //ios::sync_with_stdio(false); cin.tie(0);
    for(int i=1; i<maxn; ++i) {
        f0[i][0]=1;
        for(int j=1; j<=50; ++j) f0[i][j]=f0[i][j-1]*i%mod;
    }
    ff[0]=1;
    for(int i=1; i<maxn; ++i) ff[i]=ff[i-1]*729486258%mod;
    int T=read();
    while(T--) {
        l=read(), k=read(), n=read(), m=read();
        scanf("%s", s);
        for(int i=0; i<=k; ++i) a[i]=read();
        for(int i=1; i<=l+m; ++i) {
            f[i]=0;
            for(int j=0; j<=k; ++j) f[i]=(f[i]+a[j]*f0[i][j])%mod;
        }
        for(int i=1; i<=l+m; ++i) f[i]=f[i]*ff[i]%mod*(n-i+1)%mod;
        for(int i=1; i<=l+m; ++i) f[i]=(f[i]+f[i-1])%mod;
        for(int i=0; s[i]; ++i) add(s[i]-'a');
        printf("%lld\n", ans);
        while(m--) {
            char t[3]; scanf("%s", t);
            add(t[0]-'a');
            printf("%lld\n", ans);
        }
        init();
    }
}