找相同字符

题意:给定两个字符串,在两个字符串中任选子串,求两子串相同的方案数

思路:

  1. 用后缀自动机处理两个字符串,一般考虑在中间插入一个用不到的字符(比如’#’),然后将它们组成一个串(后来发现好像不能插’#’,毕竟我数组第二维只开了26的大小;然后有一个方法就是把数组开成27的,然后插入一个数字26就行了)
  2. 分别记录每个节点有多少个属于第一个字符串和第二个字符串的endpos,最后累加答案即可
  3. 注意爆int,叠加答案采用 1 L L c n t [ i ] [ 0 ] c n t [ i ] [ 1 ] ( l e n [ i ] l e n [ f a [ i ] ] ) 1LL*cnt[i][0]*cnt[i][1]*(len[i]-len[fa[i]]) 1LLcnt[i][0]cnt[i][1](len[i]len[fa[i]])

题目描述

给定两个字符串,求出在两个字符串中各取出一个子串使得这两个子串相同的方案数。两个方案不同当且仅当这两个子串中有一个位置不同。

输入格式

两行,两个字符串s1,s2,长度分别为n1,n2。1 <=n1, n2<= 200000,字符串中只有小写字母

输出格式

输出一个整数表示答案

输入输出样例

输入
aabb
bbaa
输出
10

//#pragma comment(linker, "/STACK:102400000,102400000")
#include "bits/stdc++.h"
#define pb push_back
#define ls l,m,now<<1
#define rs m+1,r,now<<1|1
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}

const int maxn = 1e6+10;
const int mod = 1e9+7;
const double eps = 1e-9;

char s0[maxn], s1[maxn];
int ch[maxn][27], fa[maxn], len[maxn];
int last=1, sz=1;
int cnt[maxn][2];
int a[maxn], c[maxn];

void add(int c, int f) {
    int p=last, np=last=++sz;
    len[np]=len[p]+1, cnt[np][f]=1;
    for(; p&&!ch[p][c]; p=fa[p]) ch[p][c]=np;
    if(!p) fa[np]=1;
    else {
        int q=ch[p][c];
        if(len[q]==len[p]+1) fa[np]=q;
        else {
            int nq=++sz;
            fa[nq]=fa[q], len[nq]=len[p]+1;
            memcpy(ch[nq],ch[q],108);
            fa[q]=fa[np]=nq;
            for(; p&&ch[p][c]==q; p=fa[p]) ch[p][c]=nq;
        }
    }
}

int main() {
    //ios::sync_with_stdio(false);
    scanf("%s%s", s0, s1);
    int len0=strlen(s0), len1=strlen(s1);
    for(int i=0; i<len0; ++i) add(s0[i]-'a',0);
    add(26,0);
    for(int i=0; i<len1; ++i) add(s1[i]-'a',1);

    for(int i=1; i<=sz; ++i) c[len[i]]++;
    for(int i=1; i<=sz; ++i) c[i]+=c[i-1];
    for(int i=1; i<=sz; ++i) a[c[len[i]]--]=i;

    for(int i=sz; i; --i) cnt[fa[a[i]]][0]+=cnt[a[i]][0],
                          cnt[fa[a[i]]][1]+=cnt[a[i]][1];
    ll ans=0;
    for(int i=1; i<=sz; ++i) ans+=1LL*cnt[i][0]*cnt[i][1]*(len[i]-len[fa[i]]);
    cout<<ans<<endl;
}