牛客练习赛71 E- 神奇的迷宫
题意
给一颗个点的树,每条边的长度均为,Alice和Bob两人依次传送到树的某两个结点。对于任意一个人,传送到点的概率为,假设两人传送到的结点之间的最短距离为,那么他们挑战这个树的困难度为。
问他们挑战这个树的困难度的期望是多少。
分析
令表示两人最短距离为的概率,答案即为。
求可以用点分治来做,以作为分治中心时,枚举每个子树,用表示已经枚举过的子树中到根的距离为的点的概率之和,用表示当前子树中到根的距离为的点的概率之和,那么就可以更新,注意到这是一个卷积形式,所以我们对做一次卷积就能更新,因为答案要取模,所以用来做卷积。
复杂度为。
Code
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N=1e6+10; const int mod = 998244353, G = 3, Gi = 332748118; int n; ll p[N],w[N]; vector<int>g[N]; int sz[N],vis[N],mx[N],rt,tot,k1,k2; ll ans,A[N],B[N]; int limit = 1, L, r[N]; ll a[N], b[N]; ll ksm(ll a, ll b) { ll ret = 1; while(b) { if(b & 1) ret = (ret * a ) % mod; a = (a * a) % mod; b >>= 1; } return ret; } void NTT(ll *A, int type) { for(int i = 0; i < limit; i++) if(i < r[i]) swap(A[i], A[r[i]]); for(int mid = 1; mid < limit; mid <<= 1) { ll Wn = ksm( type == 1 ? G : Gi , (mod - 1) / (mid << 1)); for(int j = 0; j < limit; j += (mid << 1)) { ll w = 1; for(int k = 0; k < mid; k++, w = (w * Wn) % mod) { int x = A[j + k], y = w * A[j + k + mid] % mod; A[j + k] = (x + y) % mod, A[j + k + mid] = (x - y + mod) % mod; } } } } void gao() { limit=1,L=0; for(int i=0;i<=k1;i++) a[i]=A[i]; for(int i=0;i<=k2;i++) b[i]=B[i]; while(limit <= k1 + k2) limit <<= 1, L++; for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1)); NTT(a, 1);NTT(b, 1); for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % mod; NTT(a, -1); ll inv = ksm(limit, mod - 2); for(int i=0;i<=k1+k2;i++) a[i]=(a[i]*inv)%mod; for(int i = 0; i <= k1 + k2&&i<n; i++){ ans=(ans+a[i] * w[i]%mod*2%mod)%mod; } for(int i=0;i<=limit;i++) a[i]=b[i]=r[i]=0; } void getrt(int u,int fa){ sz[u]=1,mx[u]=0; for(int x:g[u]){ if(x==fa||vis[x]) continue; getrt(x,u); sz[u]+=sz[x]; mx[u]=max(mx[u],sz[x]); } mx[u]=max(mx[u],tot-sz[u]); if(mx[u]<mx[rt]) rt=u; } void dfs(int u,int fa,int d){ B[d]=(B[d]+p[u])%mod; k2=max(k2,d); for(int x:g[u]){ if(x==fa||vis[x]) continue; dfs(x,u,d+1); } } void solve(int u){ vis[u]=1;k1=k2=0; A[0]=p[u]; for(int x:g[u]){ if(vis[x]) continue; k2=0; dfs(x,u,1); /* for(int i=0;i<=k1;i++){ for(int j=0;j<=k2;j++) if(i+j<n){ ans+=w[i+j]*A[i]%mod*B[j]%mod*2%mod; ans%=mod; } } */ gao(); k1=max(k1,k2); for(int i=0;i<=k2;i++) A[i]=(A[i]+B[i])%mod,B[i]=0; } for(int i=0;i<=k1;i++) A[i]=0; for(int x:g[u]){ if(vis[x]) continue; tot=sz[x],mx[rt=0]=n; getrt(x,0); solve(rt); } } int main() { scanf("%d",&n); for(int i=1;i<=n;i++){ scanf("%lld",&p[i]); p[0]+=p[i]; if(p[0]>=mod) p[0]-=mod; } p[0]=ksm(p[0],mod-2); for(int i=1;i<=n;i++){ scanf("%lld",&w[i-1]); p[i]=p[i]*p[0]%mod; } for(int i=2,x,y;i<=n;i++){ scanf("%d%d",&x,&y); g[x].push_back(y); g[y].push_back(x); } tot=mx[rt]=n; getrt(1,0); solve(rt); for(int i=1;i<=n;i++) ans=(ans+w[0]*p[i]%mod*p[i]%mod)%mod; printf("%lld\n",ans); return 0; }