题意

给出一个长串 t t t ,n个短串 s i s_i si ,问两两连接成的 s i + s j s_i+s_j si+sj t t t 的出现次数的和为多少

题解

从另一个角度分析,每次的匹配,在 t t t 中的 s i s_i si s j s_j sj 的连接位置 +1, 我们枚举 t t t 的每一个对应连接的位置,看有多少对能匹配上

若在位置 k k k 能匹配上,有一个特征, s i = t [ i l e n ( s i ) + 1 , k ] , s j = t [ k + 1 , i + l e n ( s j ) 1 ] s_i=t[i-len(s_i)+1,k],s_j=t[k+1,i+len(s_j)-1] si=t[ilen(si)+1,k],sj=t[k+1,i+len(sj)1]
就是说, t t t 的后缀等于 s i s_i si t t t 的前缀等于 s j s_j sj
把 n 个串插到 AC 自动机,用 t 匹配就好了,然后所有串翻转,再重复一遍
注意
由于同一种串在不同位置算多次,不能暴力跳 f a i l fail fail , 有一个小优化,就是计算 c n t cnt cnt 的时候从 f a i l fail fail 继承下来

代码

#include<bits/stdc++.h>
#define N 200010
#define INF 0x3f3f3f3f
#define eps 1e-6
#define pi acos(-1.0)
#define mod 998244353
#define P 1000000007
#define LL long long
#define pb push_back
#define fi first
#define se second
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define bug(x) cerr<<#x<<" : "<<x<<endl
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y) 
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;
struct Trie{
    int t[N][26],fail[N],d[N],w[N];
    int tot;
    void init(){ tot=0;
        memset(fail,0,sizeof fail);
        memset(t,0,sizeof t);
        memset(d,0,sizeof d);
    }
    void ins(char* x){
        int i=0;
        for(int j=0;x[j];j++){
            int c=x[j]-'a';
            if(!t[i][c]) t[i][c]=++tot;
            i=t[i][c];
        }
        d[i]++;
    }
    void get_fail(){
        queue<int> q;
        for (int i=0;i<26;i++) if (t[0][i]) q.push(t[0][i]);
        while(!q.empty()) {
            int i=q.front();q.pop();
            for(int j=0;j<26;j++) {
                if (t[i][j]){
                    fail[t[i][j]]=t[fail[i]][j];
                    d[t[i][j]]+=d[fail[t[i][j]]];
                    q.push(t[i][j]);
                }else
                t[i][j]=t[fail[i]][j];
            } 
        }
    }
    void AC_automation(char *x){
        int i=0;
        for(int j=0;x[j];j++) {
            i=t[i][x[j]-'a'];
            w[j]=d[i];
        }
    }
}T1,T2;


char t[N],s[N];

int main(){
    scanf("%s",t);
    int n;
    sc(n);
    while(n--){
        scanf("%s",s);
        T1.ins(s);
        int l=strlen(s);
        reverse(s,s+l);
        T2.ins(s);
    }
    T1.get_fail(); T2.get_fail();

    T1.AC_automation(t);
    
    int l=strlen(t); reverse(t,t+l);
    
    T2.AC_automation(t);
    LL ans=0;
    for(int i=0;i<l-1;i++) ans+=1ll*T1.w[i]*T2.w[l-i-2];
    cout<<ans;
}