题意:
给出两个01串A,b,记 ai a i 表示A中1的出现位置, bi b i 表示B中1的出现位置,将a数组和b数组打乱后依次次交换 Aai A a i 和 Abi A b i ,求有几种方式使得A=B
字符串长度<=10000
Solution:
我们可以把答案拆分成两步:
1.枚举a和b的匹配
2.打乱匹配顺序
假设我们已经完成了操作1,我们来计算每个匹配所能产生的期望合法方案
尝试转化一下模型:对于一个给定的匹配,我们从 ai a i 向 bi b i 连一条有向边,可以发现这个图最终由若干个环和若干条链构成,且链的顺序是唯一的
假设有e个 Ai=Bi=1 A i = B i = 1 ,m个 Ai=1,Bi=0 A i = 1 , B i = 0 ,可以发现边数为e+m,图由m条链和若干环组成
考虑将e个点分配到m条链中, f[i][j] f [ i ] [ j ] 表示前i条链分到j个点的期望合法方案
那么有转移: f[i][j]=∑u≤ju=0f[i−1][j−u](u+1)! f [ i ] [ j ] = ∑ u = 0 u ≤ j f [ i − 1 ] [ j − u ] ( u + 1 ) ! (为什么要除(u+1)!呢?因为一共加入了u+1条边,这些边有(u+1)!种匹配方式,而在这些匹配方式中只有一种是合法的)
最终的答案即为 e!∗m!∗(e+m)!∗∑j≤ej=0f[m][j] e ! ∗ m ! ∗ ( e + m ) ! ∗ ∑ j = 0 j ≤ e f [ m ] [ j ]
e!表示点的总分配方式,m!表示链的不同排序数,(e+m)!表示总匹配数
朴素做法 O(n3) O ( n 3 ) ,可以用NTT+快速幂优化到 O(nlog2n) O ( n log 2 n )
O(n3) O ( n 3 ) 代码:
#include<cstdio>
#include<iostream>
#include<cstring>
using namespace std;
char a[100010],b[100010];
int n,num,f[510][510],tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
int fast_pow(int a,int x)
{
int ans=1;
for (;x;x>>=1,a=1ll*a*a%mod)
if (x&1) ans=1ll*ans*a%mod;
return ans;
}
int main()
{
scanf("%s%s",a+1,b+1);
n=strlen(a+1);
for (int i=1;i<=n;i++)
{
if (a[i]=='1'&&b[i]=='1') num++;
if (a[i]=='1') tot++;
}
jc[0]=1;
for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
inv[tot+1]=fast_pow(jc[tot+1],mod-2);
for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
f[0][0]=1;
for (int i=1;i<=tot-num;i++)
{
for (int j=0;j<=num;j++)
for (int k=0;k<=j;k++)
f[i][j]=(1ll*f[i-1][j-k]*inv[k+1]+f[i][j])%mod;
for (int j=0;j<=num;j++) printf("%d ",f[i][j]);cout<<endl;
}
for (int i=0;i<=num;i++)
ans=(ans+f[tot-num][i])%mod;
printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}
O(nlog2n) O ( n log 2 n ) 代码:
#include<cstdio>
#include<iostream>
#include<cstring>
#include<cmath>
using namespace std;
char a[100010],b[100010];
int n,num,tot;
const int mod=998244353;
int jc[100010],inv[100010],ans;
const int G=3;
int x1[100010],x2[100010];
int fast_pow(int a,int x)
{
int ans=1;
for (;x;x>>=1,a=1ll*a*a%mod)
if (x&1) ans=1ll*ans*a%mod;
return ans;
}
void change(int y[],int len)
{
int i,j,k;
for (i=1,j=len/2;i<len-1;i++)
{
if (i<j) swap(y[i],y[j]);
k=len/2;
while (j>=k) j-=k,k>>=1;
if (j<k) j+=k;
}
return;
}
void fft(int y[],int len,int ifi)
{
change(y,len);
for (int h=2;h<=len;h<<=1)
{
int wn=fast_pow(G,(ifi==1)?(mod-1)/h:mod-1-(mod-1)/h);
for (int j=0;j<len;j+=h)
{
int w=1;
for (int k=j;k<j+h/2;k++)
{
int u=y[k];
int t=1ll*w*y[k+h/2]%mod;
y[k]=(u+t)%mod;
y[k+h/2]=(1ll*u-t+mod)%mod;
w=1ll*w*wn%mod;
}
}
}
if (ifi==-1)
{
int iv=fast_pow(len,mod-2);
for (int i=0;i<len;i++) y[i]=1ll*y[i]*iv%mod;
}
}
void add(int len)
{
fft(x2,len,1);
for (int i=0;i<len;i++) x2[i]=1ll*x2[i]*x2[i]%mod;
fft(x2,len,-1);
for (int i=num+1;i<len;i++) x2[i]=0;
}
int main()
{
scanf("%s%s",a+1,b+1);
n=strlen(a+1);
for (int i=1;i<=n;i++)
{
if (a[i]=='1'&&b[i]=='1') num++;
if (a[i]=='1') tot++;
}
jc[0]=1;
for (int i=1;i<=tot+1;i++) jc[i]=1ll*jc[i-1]*i%mod;
inv[tot+1]=fast_pow(jc[tot+1],mod-2);
for (int i=tot;i>=0;i--) inv[i]=1ll*inv[i+1]*(i+1)%mod;
for (int i=0;i<=tot;i++) inv[i]=inv[i+1];
int len=1;
while (len<=2*num) len<<=1;
x1[0]=1;
for (int i=1;i<len;i++) x1[i]=0;
for (int i=0;i<=num;i++) x2[i]=inv[i];
for (int i=num+1;i<len;i++) x2[i]=0;
for (int i=tot-num;i;i>>=1,add(len))
if (i&1)
{
fft(x2,len,1);fft(x1,len,1);
for (int j=0;j<len;j++) x1[j]=1ll*x1[j]*x2[j]%mod;
fft(x1,len,-1);fft(x2,len,-1);
for (int j=num+1;j<len;j++) x1[j]=0;
}
int ans=0;
for (int i=0;i<=num;i++) ans=(ans+x1[i])%mod;
printf("%d",1ll*ans*jc[tot-num]%mod*jc[num]%mod*jc[tot]%mod);
}