题解:

题目的意思就是在第一个串里找“s1s2s3”,第二个串里找“s4”,如上拼接后,是一个回文串,求方案数

可以发现,s1与s4是回文的,s2和s3是回文的,我们枚举s1的右端点,s1的长度乘以s2起始点为左边界的回文串的数量,累加就是答案。

现在分两部分,一是求s1,二是求以每个点为左边界的回文串的数量

一的话,就是求每个后缀匹配第二个串的LCP,可以用扩展kmp求得,也可以用hash加二分求得,二的话,用马拉车算法+前缀和就可以解决。

代码:

#include<bits/stdc++.h>
#define N 1000010
#define INF 0x3f3f3f3f
#define eps 1e-10
#define pi 3.141592653589793
#define mod 998244353
#define LL long long
#define pb push_back
#define cl clear
#define si size
#define lb lower_bound
#define ub upper_bound
#define mem(x) memset(x,0,sizeof x)
#define sc(x) scanf("%d",&x)
#define scc(x,y) scanf("%d%d",&x,&y)
#define sccc(x,y,z) scanf("%d%d%d",&x,&y,&z)
using namespace std;

int lens,lent;
char s[N<<1],t[N],ss[N<<1];
int d[N],r[N<<1];

int Init()
{
    int len = strlen(s+1);
    ss[0] = '$';ss[1] = '#';
    int j = 2;
    for (int i = 1; i <= len; i++)ss[j++] = s[i],ss[j++] = '#';
    ss[j] = '\0';
    return j;
}

void Manacher()
{
    int len=Init();
    int p,mx=0;
    for (int i = 1; i < len; i++)
    {
        if (i<mx) r[i]=min(r[2*p-i],mx-i);else r[i] = 1;
        while (ss[i-r[i]]==ss[i+r[i]]) r[i]++;
        if (mx<i+r[i])p=i,mx=i+r[i];
    }
    for (int i=2;i<len;i++)
    {
        if (ss[i]=='#' && r[i]==1) continue;
        int x=i/2-r[i]/2+1,y=i/2+r[i]/2-!(i&1);
        d[x]++;d[(x+y)/2+1]--;
    }
}

LL p1[N],p2[N],h1[N],h2[N],h3[N],h4[N];

const LL m1=998244353;
const LL m2=100000007;

LL spy(int x,int y)
{
    LL t1=(h1[lens-x+1]-(LL)h1[lens-y]*p1[y-x+1]%m1+m1)%m1;
    LL t2=(h2[lens-x+1]-(LL)h2[lens-y]*p2[y-x+1]%m2+m2)%m2;
    return t1<<31|t2;
}

LL spyer(int y)
{
    return h3[y]<<31|h4[y];
}

int main()
{
    p1[0]=p2[0]=1;
    for (int i=1;i<N;i++) p1[i]=p1[i-1]*377%m1,p2[i]=p2[i-1]*377%m2;

    while(~scanf("%s%s",s+1,t+1))
    {
        mem(d);
        lens=strlen(s+1); lent=strlen(t+1);
        for (int i=1;i<=lens;i++) d[i]=0;
        Manacher();
        for (int i=1;i<=lens;i++) d[i]+=d[i-1];
        strcpy(ss+1,s+1);
        reverse(s+1,s+lens+1);
        for (int i=1;i<=lens;i++)
            h1[i]=(h1[i-1]*377+s[i])%m1,
            h2[i]=(h2[i-1]*377+s[i])%m2;
        for (int i=1;i<=lent;i++)
            h3[i]=(h3[i-1]*377+t[i])%m1,
            h4[i]=(h4[i-1]*377+t[i])%m2;
        LL ans=0;
        for (int i=1;i<lens;i++)
        {
            if (ss[i]!=t[1]) continue;
            int l=1,r=i>lent?lent:i;
            LL k=1;
            while(l<=r)
            {
                int m=l+r>>1;
                if (spy(i-m+1,i)==spyer(m))
                    k=m,l=m+1;else r=m-1;
            }
            ans+=k*d[i+1];
        }
        printf("%lld\n",ans);
    }
}