思路:
这道题首先是求关于路径的情况,我们需要直到用题目中已知所有的WL,为了求出最终的期望,我们应该求出每一个L的概率Pl,这是对某个路径长度的全部情况在树上求解,显然是用点分治处理。
用母函数的角度来考虑
f(x)=a0x^0+a1x^1+a2x^2+...+an-1x^n-1,a是概率
以u为当前根结点,f(x)中x^i是经过当前u且到u距离为i(i > 0)的点分布在子树中的概率,f(x)*f(x)可以得到经过u的两条长度为j,k,(j+k==i)的线路合并为长度为i的,也就是一条距离为i的路径在子树中的概率。按照点分治的思路,当前u为根结点,只算经过u的路径,所以去掉两端点在同一颗子树的情况,然后再补上一个点如果在u的情况。
最后答案再补上两个点都在一个位置的情况的概率。
通过这道题还发现自己对于vector的动态空间不了解,vector和数组还是有区别的,比如直接定义一个vector<int> f, f[0]是不存在的,resize后才有。
代码:</int>

#include<bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll;
const int mod = 998244353;
const int maxn = 1e5+7;
inline int read(){
    int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
    return x*f;
}
int q_pow(int a, int b){
    ll ans = 1;
    while(b > 0){
        if(b & 1){
            ans = ans * a % mod;
        }
        a = 1ll * a * a % mod;
        b >>= 1; 
    } 
    return (int)ans;
}


int n;
namespace Poly
{
    const int N = 3e5+5;
    int m, lim, r[N], w[N], f[N], g[N];
    void prework()
    {
        for(m = 1; m <= 2 * n; m <<= 1, lim++); lim--;
        for(int rt, i = 1; i < m; i <<= 1)
        {
            rt = q_pow(3, (mod - 1) / (i << 1)), w[i] = 1;
            for(int j = 1; j < i; j++) w[i + j] = 1ll * w[i + j - 1] * rt % mod;
        }
    }
    void NTT(int *p, int op)
    {
        for(int i = 0; i < m; i++) if(i > r[i]) swap(p[i], p[r[i]]);
        for(int x, y, i = 1; i < m; i <<= 1)
            for(int j = 0; j < m; j += (i << 1))
                for(int k = 0; k < i; k++)
                {
                    x = p[j + k], y = 1ll * w[i + k] * p[i + j + k] % mod;
                    p[j + k] = (x + y) % mod, p[i + j + k] = (x - y + mod) % mod;
                }
        if(op == 1) return;
        reverse(p + 1, p + m);
        for(int inv = q_pow(m, mod - 2), i = 0; i < m; i++)
            p[i] = 1ll * p[i] * inv % mod;
    }
    vector<int> mul(vector<int> a, vector<int> b)
    {
        int sz = a.size() + b.size() - 1; lim = 0;
        for(m = 1; m <= sz; m <<= 1, lim++); lim--;
        for(int i = 1; i < m; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << lim);
        for(int i = 0; i < m; i++) f[i] = g[i] = 0;
        for(int i = 0; i < a.size(); i++) f[i] = a[i];
        for(int i = 0; i < b.size(); i++) g[i] = b[i];
        NTT(f, 1), NTT(g, 1);
        for(int i = 0; i < m; i++) f[i] = 1ll * f[i] * g[i] % mod;
        NTT(f, -1); a.resize(sz);
        for(int i = 0; i < a.size(); i++) a[i] = f[i];
        return a;
    }
};

int p[maxn], w[maxn];

int head[maxn], tot;
struct Node{
    int to, nxt;
}e[maxn<<1];

void addedge(int u, int v){
    e[++tot].to = v;
    e[tot].nxt = head[u];
    head[u] = tot;
}

int maxp[maxn], siz[maxn], sum, rt;
bool vis[maxn];
void getrt(int u, int fa)
{
    siz[u]=1, maxp[u]=0;
    for(int i = head[u]; ~i; i = e[i].nxt)
    {
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        getrt(v, u);
        siz[u] += siz[v];
        if(siz[v] > maxp[u]) maxp[u] = siz[v];
    }
    maxp[u] = max(maxp[u], sum-siz[u]);
    if(maxp[u] < maxp[rt]) rt = u;
}

vector<int> son[maxn];
int cnt[maxn];
vector<int> f[maxn];

void dfs(int x, int u, int fa, int dep) {
    if(dep >= son[x].size()) {
        son[x].push_back(p[u]);
    }
    else {
        son[x][dep] = (son[x][dep] + p[u]) % mod;
    }
    for(int i = head[u]; ~i; i = e[i].nxt){
        int v = e[i].to;
        if(v == fa || vis[v]) continue;
        dfs(x, v, u, dep+1);
    }
}

void solve(int u)
{
    vector<int> res;
    int sz;
    for(int i = head[u]; ~i; i = e[i].nxt)
    {
        int v = e[i].to;
        if(vis[v]) continue;
        son[v].clear();
        son[v].push_back(0);
        dfs(v, v, u, 1);
        res = Poly::mul(son[v], son[v]);
        sz = res.size();
        for(int j = 0; j < sz; j++) cnt[j] = (cnt[j] - res[j] + mod) % mod;
        sz = son[v].size();
        if(sz > f[u].size()) f[u].resize(sz);
        for(int j = 0; j < sz; j++) f[u][j] = (f[u][j] + son[v][j]) % mod;
    }
//    f[u][0] = (1ll*f[u][0] + p[u] + mod) % mod;
//    res = Poly::mul(f[u], f[u]);
//    for(int j = 0; j < sz; j++) cnt[j] = (1ll*cnt[j] + res[j] + mod) % mod;
    if(f[u].size() == 0) return;
    res = Poly::mul(f[u], f[u]);
    for(int j = 0; j < res.size(); j++) cnt[j] = (cnt[j] + res[j]) % mod;
    for(int j = 0; j < f[u].size(); j++) f[u][j] = 2ll*f[u][j] * p[u] % mod;
    for(int j = 0; j < f[u].size(); j++) cnt[j] = (cnt[j] + f[u][j]) % mod;
}

void divide(int u)
{
    vis[u] = 1;
    solve(u);
    for(int i = head[u]; ~i; i = e[i].nxt)
    {
        int v = e[i].to;
        if(vis[v]) continue;
        maxp[rt=0]=sum=siz[v];
        getrt(v, 0);
        getrt(rt, 0);
        divide(rt);
    }
}
signed main()
{
    scanf("%lld", &n);
    memset(head, -1, sizeof(head)); Poly::prework();
    ll s = 0;
    for(int i = 1; i <= n; i++) p[i] = read(), s = (s + p[i]) % mod;
    s = q_pow(s, mod - 2);
//    cout<<s<<'\n';
    for(int i = 1; i <= n; i++) p[i] = 1ll * p[i] * s % mod;
    for(int i = 0; i <= n - 1; i++) w[i] = read();
    for(int i = 1; i <= n - 1; i++) {
        int u, v;
        u = read(); v = read();
        addedge(u, v);
        addedge(v, u);
    }
    maxp[0] = sum = n;
    getrt(1, 0);
    getrt(rt, 0);
    divide(rt);
    ll ans = 0;
    for(int i = 1; i <= n; i++) {
        cnt[0] = (cnt[0] + p[i] * p[i] % mod) % mod;
    }
//    cout<<cnt[0]<<'\n';
    for(int i = 0; i <= n - 1; i++) {
        ans = (ans + w[i] * cnt[i] % mod ) % mod;
    }
    cout << ans << endl;
    return 0;
}