题意:
“简单无向图”是指无重边、无自环的无向图(不一定连通)。
一个带标号的图的价值定义为每个点度数的k次方的和。
给定n和k,请计算所有n个点的带标号的简单无向图的价值之和。
因为答案很大,请对998244353取模输出。
1≤n≤109,1≤k≤200000 1 ≤ n ≤ 10 9 , 1 ≤ k ≤ 200000
Solution:
易知每个点的贡献是一样的
所以我们只需考虑一个点的贡献:
枚举这个点的度数,剩下的边可以随意填
sum=2(n−2)∗(n−1)∗∑i≤n−1i=0Cin−1∗ik s u m = 2 ( n − 2 ) ∗ ( n − 1 ) ∗ ∑ i = 0 i ≤ n − 1 C n − 1 i ∗ i k
我们只需要考虑求 ∑i≤ni=0Cin∗ik ∑ i = 0 i ≤ n C n i ∗ i k
我们知道 nk=∑i≤ki=0Sik∗Cin∗i! n k = ∑ i = 0 i ≤ k S k i ∗ C n i ∗ i ! (S是第二类斯特林数)
那么
∑i≤ni=0Ckn∗ik ∑ i = 0 i ≤ n C n k ∗ i k
=∑i≤ni=0Cin∗∑j≤min(i,k)j=0Sjk∗Cji∗j! = ∑ i = 0 i ≤ n C n i ∗ ∑ j = 0 j ≤ m i n ( i , k ) S k j ∗ C i j ∗ j !
=∑i≤min(n,k)i=0Sik∗i!∗∑j≤nj=iCjn∗Cij = ∑ i = 0 i ≤ m i n ( n , k ) S k i ∗ i ! ∗ ∑ j = i j ≤ n C n j ∗ C j i
∑j≤nj=iCjn∗Cij ∑ j = i j ≤ n C n j ∗ C j i 的组合意义为在n个数里面选i个,剩下的随便选
所以 ∑j≤nj=iCjn∗Cij=Cin∗2n−i ∑ j = i j ≤ n C n j ∗ C j i = C n i ∗ 2 n − i
继续推原式可得
=∑i≤min(n,k)i=0Sik∗i!∗Cin∗2n−i = ∑ i = 0 i ≤ m i n ( n , k ) S k i ∗ i ! ∗ C n i ∗ 2 n − i
现在的问题就在于如何快速计算S
根据容斥原理,S是有一个公式的:
Smn=1m!∗∑mi=0(−1)i∗Cim∗(m−i)n S n m = 1 m ! ∗ ∑ i = 0 m ( − 1 ) i ∗ C m i ∗ ( m − i ) n
=∑mi=0(−1)ii!∗(m−i)n(m−i)! = ∑ i = 0 m ( − 1 ) i i ! ∗ ( m − i ) n ( m − i ) !
我们发现这是两个函数的卷积的形式,所以说可以使用fft来优化
组合数可以预处理,斯特林数可以NTT
所以总复杂度 O(klogk) O ( k log k )
代码:
#include<cstdio>
#include<cstring>
#include<iostream>
using namespace std;
const int mod=998244353;
const int g=3;
const int N=200010;
int mi[N],ni[N],n,k,nmi[N];
int a1[524290],a2[524290];
int fast_pow(long long x,int a)
{
int ans=1;
for (;x;x>>=1,a=1ll*a*a%mod)
if (x&1) ans=1ll*ans*a%mod;
return ans;
}
void change(int y[],int len)
{
int i,j,k;
for (i=1,j=len/2;i<len-1;i++)
{
if (i<j) swap(y[i],y[j]);
k=len/2;
while (j>=k) j-=k,k>>=1;
if (j<k) j+=k;
}
}
void fft(int y[],int len,int ifi)
{
change(y,len);
for (int h=1;h<=len;h<<=1)
{
int wn=(ifi==1)?fast_pow((mod-1)/h,g):fast_pow(mod-1-(mod-1)/h,g);
for (int i=0;i<len;i+=h)
{
int w=1;
for (int k=i;k<i+h/2;k++)
{
int u=y[k];
int v=1ll*w*y[k+h/2]%mod;
y[k]=(u+v)%mod;
y[k+h/2]=(1ll*u-v+mod)%mod;
w=1ll*w*wn%mod;
}
}
}
if (ifi==-1)
{
int ny=fast_pow(mod-2,len);
for (int i=0;i<len;i++) y[i]=1ll*y[i]*ny%mod;
}
}
int main()
{
scanf("%d%d",&n,&k);mi[0]=1;n--;
for (int i=1;i<=k;i++) mi[i]=1ll*mi[i-1]*i%mod;
nmi[0]=1;
for (int i=1;i<=min(n,k);i++) nmi[i]=1ll*nmi[i-1]*(n-i+1)%mod;
ni[k]=fast_pow(mod-2,mi[k]);
for (int i=k-1;i>=0;i--) ni[i]=1ll*ni[i+1]*(i+1)%mod;
int len=1;while (len<=2*k) len<<=1;
for (int i=0;i<=k;i++) a1[i]=(fast_pow(i,-1)*ni[i]+mod)%mod;
for (int i=0;i<=k;i++) a2[i]=1ll*fast_pow(k,i)*ni[i]%mod;
fft(a1,len,1);fft(a2,len,1);for (int i=0;i<len;i++) a1[i]=1ll*a1[i]*a2[i]%mod;
fft(a1,len,-1);
int ans=0;
for (int i=0;i<=min(n,k);i++)
ans=(1ll*ans+1ll*a1[i]*nmi[i]%mod*fast_pow(n-i,2))%mod;
ans=1ll*ans*(n+1)%mod*fast_pow(1ll*n*(n-1)/2,2)%mod;
printf("%d",ans);
}