传送门

题意:

给出两个01串A,b,记 ai a i 表示A中1的出现位置, bi b i 表示B中1的出现位置,将a数组和b数组打乱后依次次交换 Aai A a i Abi A b i ,求有几种方式使得A=B

字符串长度<=10000

Solution:

我们可以把答案拆分成两步:

1.枚举a和b的匹配

2.打乱匹配顺序

假设我们已经完成了操作1,我们来计算每个匹配所能产生的期望合法方案

尝试转化一下模型:对于一个给定的匹配,我们从 ai a i bi b i 连一条有向边,可以发现这个图最终由若干个环和若干条链构成,且链的顺序是唯一的

假设有e个 Ai=Bi=1 A i = B i = 1 ,m个 Ai=1,Bi=0 A i = 1 , B i = 0 ,可以发现边数为e+m,图由m条链和若干环组成

考虑将e个点分配到m条链中, f[i][j] f [ i ] [ j ] 表示前i条链分到j个点的期望合法方案

那么有转移: f[i][j]=uju=0f[i1][ju](u+1)! f [ i ] [ j ] = ∑ u = 0 u ≤ j f [ i − 1 ] [ j − u ] ( u + 1 ) ! (为什么要除(u+1)!呢?因为一共加入了u+1条边,这些边有(u+1)!种匹配方式,而在这些匹配方式中只有一种是合法的)

最终的答案即为 e!m!(e+m)!jej=0f[m][j] e ! ∗ m ! ∗ ( e + m ) ! ∗ ∑ j = 0 j ≤ e f [ m ] [ j ]

e!表示点的总分配方式,m!表示链的不同排序数,(e+m)!表示总匹配数

朴素做法 O(n3) O ( n 3 ) ,可以用NTT+快速幂优化到 O(nlog2n) O ( n log 2 ⁡ n )

O(n3) O ( n 3 ) 代码:

#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
char a[100010],b[100010];
int n,num,f[510][510],tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
int fast_pow(int a,int x)
{
    int ans=1;
    for (;x;x>>=1,a=1ll*a*a%mod)
        if (x&1) ans=1ll*ans*a%mod;
    return ans;
}
int main()
{
    scanf("%s%s",a+1,b+1);
    n=strlen(a+1);
    for (int i=1;i<=n;i++)
    {
        if (a[i]=='1'&&b[i]=='1') num++;
        if (a[i]=='1') tot++;
    }
    jc[0]=1;
    for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
    inv[tot+1]=fast_pow(jc[tot+1],mod-2);
    for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod; 
    f[0][0]=1;
    for (int i=1;i<=tot-num;i++)
    {
        for (int j=0;j<=num;j++)
            for (int k=0;k<=j;k++)
                f[i][j]=(1ll*f[i-1][j-k]*inv[k+1]+f[i][j])%mod; 
        for (int j=0;j<=num;j++) printf("%d ",f[i][j]);cout<<endl;
    }
    for (int i=0;i<=num;i++)
        ans=(ans+f[tot-num][i])%mod;
    printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}

O(nlog2n) O ( n log 2 ⁡ n ) 代码:

#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
char a[100010],b[100010];
int n,num,tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
const int G=3;
int x1[100010],x2[100010];
int fast_pow(int a,int x)
{
    int ans=1;
    for (;x;x>>=1,a=1ll*a*a%mod)
        if (x&1) ans=1ll*ans*a%mod;
    return ans;
}
void change(int y[],int len)
{
    int i,j,k;
    for (i=1,j=len/2;i<len-1;i++)
    {
        if (i<j) swap(y[i],y[j]);
        k=len/2;
        while (j>=k) j-=k,k>>=1;
        if (j<k) j+=k; 
    }
    return;
}
void fft(int y[],int len,int ifi)
{
    change(y,len);
    for (int h=2;h<=len;h<<=1)
    {
        int wn=fast_pow(G,(ifi==1)?(mod-1)/h:mod-1-(mod-1)/h);
        for (int j=0;j<len;j+=h)
        {
            int w=1;
            for (int k=j;k<j+h/2;k++)
            {
                int u=y[k];
                int t=1ll*w*y[k+h/2]%mod;
                y[k]=(u+t)%mod;
                y[k+h/2]=(1ll*u-t+mod)%mod;
                w=1ll*w*wn%mod;
            } 
        } 
    }
    if (ifi==-1)
    {
        int iv=fast_pow(len,mod-2);
        for (int i=0;i<len;i++) y[i]=1ll*y[i]*iv%mod;
    }
}
void add(int len)
{
    fft(x2,len,1);
    for (int i=0;i<len;i++) x2[i]=1ll*x2[i]*x2[i]%mod;
    fft(x2,len,-1);
    for (int i=num+1;i<len;i++) x2[i]=0;
}
int main()
{
    scanf("%s%s",a+1,b+1);
    n=strlen(a+1);
    for (int i=1;i<=n;i++)
    {
        if (a[i]=='1'&&b[i]=='1') num++;
        if (a[i]=='1') tot++;
    }
    jc[0]=1;
    for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
    inv[tot+1]=fast_pow(jc[tot+1],mod-2);
    for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod; 
    for (int i=0;i<=tot;i++) inv[i]=inv[i+1];

    int len=1;
    while (len<=2*num) len<<=1;
    x1[0]=1;
    for (int i=1;i<len;i++) x1[i]=0;
    for (int i=0;i<=num;i++) x2[i]=inv[i];
    for (int i=num+1;i<len;i++) x2[i]=0;

    for (int i=tot-num;i;i>>=1,add(len))
        if (i&1)
        {
            fft(x2,len,1);fft(x1,len,1);
            for (int j=0;j<len;j++) x1[j]=1ll*x1[j]*x2[j]%mod;
            fft(x1,len,-1);fft(x2,len,-1);
            for (int j=num+1;j<len;j++) x1[j]=0;
        }
    int ans=0;
    for (int i=0;i<=num;i++) ans=(ans+x1[i])%mod;
    printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}