思路:
这道题首先是求关于路径的情况,我们需要直到用题目中已知所有的WL,为了求出最终的期望,我们应该求出每一个L的概率Pl,这是对某个路径长度的全部情况在树上求解,显然是用点分治处理。
用母函数的角度来考虑
f(x)=a0x^0+a1x^1+a2x^2+...+an-1x^n-1,a是概率
以u为当前根结点,f(x)中x^i是经过当前u且到u距离为i(i > 0)的点分布在子树中的概率,f(x)*f(x)可以得到经过u的两条长度为j,k,(j+k==i)的线路合并为长度为i的,也就是一条距离为i的路径在子树中的概率。按照点分治的思路,当前u为根结点,只算经过u的路径,所以去掉两端点在同一颗子树的情况,然后再补上一个点如果在u的情况。
最后答案再补上两个点都在一个位置的情况的概率。
通过这道题还发现自己对于vector的动态空间不了解,vector和数组还是有区别的,比如直接定义一个vector<int> f, f[0]是不存在的,resize后才有。
代码:</int>
#include<bits/stdc++.h> using namespace std; #define int long long typedef long long ll; const int mod = 998244353; const int maxn = 1e5+7; inline int read(){ int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();} return x*f; } int q_pow(int a, int b){ ll ans = 1; while(b > 0){ if(b & 1){ ans = ans * a % mod; } a = 1ll * a * a % mod; b >>= 1; } return (int)ans; } int n; namespace Poly { const int N = 3e5+5; int m, lim, r[N], w[N], f[N], g[N]; void prework() { for(m = 1; m <= 2 * n; m <<= 1, lim++); lim--; for(int rt, i = 1; i < m; i <<= 1) { rt = q_pow(3, (mod - 1) / (i << 1)), w[i] = 1; for(int j = 1; j < i; j++) w[i + j] = 1ll * w[i + j - 1] * rt % mod; } } void NTT(int *p, int op) { for(int i = 0; i < m; i++) if(i > r[i]) swap(p[i], p[r[i]]); for(int x, y, i = 1; i < m; i <<= 1) for(int j = 0; j < m; j += (i << 1)) for(int k = 0; k < i; k++) { x = p[j + k], y = 1ll * w[i + k] * p[i + j + k] % mod; p[j + k] = (x + y) % mod, p[i + j + k] = (x - y + mod) % mod; } if(op == 1) return; reverse(p + 1, p + m); for(int inv = q_pow(m, mod - 2), i = 0; i < m; i++) p[i] = 1ll * p[i] * inv % mod; } vector<int> mul(vector<int> a, vector<int> b) { int sz = a.size() + b.size() - 1; lim = 0; for(m = 1; m <= sz; m <<= 1, lim++); lim--; for(int i = 1; i < m; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << lim); for(int i = 0; i < m; i++) f[i] = g[i] = 0; for(int i = 0; i < a.size(); i++) f[i] = a[i]; for(int i = 0; i < b.size(); i++) g[i] = b[i]; NTT(f, 1), NTT(g, 1); for(int i = 0; i < m; i++) f[i] = 1ll * f[i] * g[i] % mod; NTT(f, -1); a.resize(sz); for(int i = 0; i < a.size(); i++) a[i] = f[i]; return a; } }; int p[maxn], w[maxn]; int head[maxn], tot; struct Node{ int to, nxt; }e[maxn<<1]; void addedge(int u, int v){ e[++tot].to = v; e[tot].nxt = head[u]; head[u] = tot; } int maxp[maxn], siz[maxn], sum, rt; bool vis[maxn]; void getrt(int u, int fa) { siz[u]=1, maxp[u]=0; for(int i = head[u]; ~i; i = e[i].nxt) { int v = e[i].to; if(v == fa || vis[v]) continue; getrt(v, u); siz[u] += siz[v]; if(siz[v] > maxp[u]) maxp[u] = siz[v]; } maxp[u] = max(maxp[u], sum-siz[u]); if(maxp[u] < maxp[rt]) rt = u; } vector<int> son[maxn]; int cnt[maxn]; vector<int> f[maxn]; void dfs(int x, int u, int fa, int dep) { if(dep >= son[x].size()) { son[x].push_back(p[u]); } else { son[x][dep] = (son[x][dep] + p[u]) % mod; } for(int i = head[u]; ~i; i = e[i].nxt){ int v = e[i].to; if(v == fa || vis[v]) continue; dfs(x, v, u, dep+1); } } void solve(int u) { vector<int> res; int sz; for(int i = head[u]; ~i; i = e[i].nxt) { int v = e[i].to; if(vis[v]) continue; son[v].clear(); son[v].push_back(0); dfs(v, v, u, 1); res = Poly::mul(son[v], son[v]); sz = res.size(); for(int j = 0; j < sz; j++) cnt[j] = (cnt[j] - res[j] + mod) % mod; sz = son[v].size(); if(sz > f[u].size()) f[u].resize(sz); for(int j = 0; j < sz; j++) f[u][j] = (f[u][j] + son[v][j]) % mod; } // f[u][0] = (1ll*f[u][0] + p[u] + mod) % mod; // res = Poly::mul(f[u], f[u]); // for(int j = 0; j < sz; j++) cnt[j] = (1ll*cnt[j] + res[j] + mod) % mod; if(f[u].size() == 0) return; res = Poly::mul(f[u], f[u]); for(int j = 0; j < res.size(); j++) cnt[j] = (cnt[j] + res[j]) % mod; for(int j = 0; j < f[u].size(); j++) f[u][j] = 2ll*f[u][j] * p[u] % mod; for(int j = 0; j < f[u].size(); j++) cnt[j] = (cnt[j] + f[u][j]) % mod; } void divide(int u) { vis[u] = 1; solve(u); for(int i = head[u]; ~i; i = e[i].nxt) { int v = e[i].to; if(vis[v]) continue; maxp[rt=0]=sum=siz[v]; getrt(v, 0); getrt(rt, 0); divide(rt); } } signed main() { scanf("%lld", &n); memset(head, -1, sizeof(head)); Poly::prework(); ll s = 0; for(int i = 1; i <= n; i++) p[i] = read(), s = (s + p[i]) % mod; s = q_pow(s, mod - 2); // cout<<s<<'\n'; for(int i = 1; i <= n; i++) p[i] = 1ll * p[i] * s % mod; for(int i = 0; i <= n - 1; i++) w[i] = read(); for(int i = 1; i <= n - 1; i++) { int u, v; u = read(); v = read(); addedge(u, v); addedge(v, u); } maxp[0] = sum = n; getrt(1, 0); getrt(rt, 0); divide(rt); ll ans = 0; for(int i = 1; i <= n; i++) { cnt[0] = (cnt[0] + p[i] * p[i] % mod) % mod; } // cout<<cnt[0]<<'\n'; for(int i = 0; i <= n - 1; i++) { ans = (ans + w[i] * cnt[i] % mod ) % mod; } cout << ans << endl; return 0; }