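/*
 * 题目可以简化为:ans[i] 表示距离为 i 的所有点对的概率乘积之和,求出每个 ans[i] 即可。
 * 考虑使用点分治分解整棵树,在每个分治中心处将各子树按最大深度从小到大排序后依次合并(启发式合并),合并时用 NTT 加速。
 * 复杂度 O(n log^2 n)。
 *
 * (English) ans[i] is the sum, over vertex pairs at distance i, of the product of their
 * probabilities. The tree is processed by centroid decomposition; at each centroid the child
 * subtrees are merged in increasing order of maximum depth, each merge being an NTT convolution.
 * Complexity: O(n log^2 n).
 */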

#include <cstdio>
#include <algorithm>
#include <cstring>
#include <vector>
using namespace std;

using ll = long long;

constexpr int N = 262144 + 100;  // array capacity: 2^18 plus some slack
constexpr int MOD = 998244353;   // prime modulus, 119 * 2^23 + 1

namespace NTT {
    // NOTE: macros ignore namespaces, so pw stays defined for the rest of the file.
    #define pw(n) (1<<n)
    // Maximum transform length, and the modulus with its primitive root:
    // P = 998244353 = 119 * 2^23 + 1, g = 3 (alternative: P = 1004535809, also with g = 3).
    const int N = 262144, P = 998244353, g = 3;
    // Shared work state for solve(): n / m are the operand degrees (m is then reused
    // for the product degree), bit is the power-of-two transform length and
    // bitnum = log2(bit); a / b are work buffers that solve() leaves all-zero after
    // each call; rev holds the bit-reversal permutation for the current length.
    int n, m, bit, bitnum = 0, a[N + 5], b[N + 5], rev[N + 5];

    // Fill rev[0 .. 2^l - 1] with the bit-reversal permutation of l-bit indices.
    void getrev(int l) {
        // rev[0] is 0 for every l. Setting it directly and starting at i = 1 means
        // l == 0 (a length-1 transform, i.e. two constant operands) no longer
        // evaluates a shift by l - 1 == -1, which is undefined behaviour.
        rev[0] = 0;
        for (int i = 1; i < pw(l); i++) {
            rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (l - 1));
        }
    }

    // Return a^b mod P by binary exponentiation (expects 0 <= a < P, b >= 0).
    int fastpow(int a, int b) {
        int ans = 1;
        for (; b; b >>= 1, a = 1LL * a * a % P) {
            if (b & 1) ans = 1LL * ans * a % P;
        }
        return ans;
    }

    // In-place iterative NTT of s[0 .. bit-1] using the current rev[] and bit.
    // op == 1: forward transform. op == -1: inverse transform, computed as a forward
    // transform followed by reversing s[1 .. bit-1] and scaling by 1 / bit.
    void NTT(int *s, int op) {
        for (int i = 0; i < bit; i++) if (i < rev[i]) std::swap(s[i], s[rev[i]]);
        for (int i = 1; i < bit; i <<= 1) {
            // Primitive root of unity of order 2i.
            int w = fastpow(g, (P - 1) / (i << 1));
            for (int p = i << 1, j = 0; j < bit; j += p) {
                int wk = 1;
                for (int k = j; k < i + j; k++, wk = 1LL * wk * w % P) {
                    int x = s[k], y = 1LL * s[k + i] * wk % P;
                    s[k] = (x + y) % P;
                    s[k + i] = (x - y + P) % P;
                }
            }
        }
        if (op == -1) {
            std::reverse(s + 1, s + bit);
            int inv = fastpow(bit, P - 2);
            for (int i = 0; i < bit; i++) s[i] = 1LL * s[i] * inv % P;
        }
    }

    // Multiply aa[0..nn] by bb[0..mm] (coefficients expected in [0, P)) modulo P.
    // Writes c[0 .. bit-1] (coefficients above nn + mm come out as 0) and returns
    // bit, the smallest power of two greater than nn + mm. aa and bb are not modified.
    // Requires nn + mm < N so the work buffers suffice; c must hold at least bit ints.
    int solve(int *aa, int nn, int *bb, int mm, int *c) {
        n = nn; m = mm;
        bit = bitnum = 0;
        for (int i = 0; i <= n; i++) a[i] = aa[i];
        for (int i = 0; i <= m; i++) b[i] = bb[i];
        m += n;  // degree of the product
        for (bit = 1; bit <= m; bit <<= 1) bitnum++;
        getrev(bitnum);
        NTT(a, 1);
        NTT(b, 1);
        for (int i = 0; i < bit; i++) a[i] = 1LL * a[i] * b[i] % P;
        NTT(a, -1);
        for (int i = 0; i < bit; i++) c[i] = a[i];
        // Restore the all-zero invariant on the work buffers for the next call.
        for (int i = 0; i < bit; i++) a[i] = b[i] = 0;
        return bit;
    }
}

// Modular exponentiation by repeated squaring: returns x^n mod MOD
// (returns 1 when n <= 0). Expects 0 <= x < MOD so intermediate products fit in ll.
ll qpow(ll x, ll n) {
    ll result = 1;
    for (; n > 0; n >>= 1) {
        if (n & 1) result = result * x % MOD;
        x = x * x % MOD;
    }
    return result;
}

int n;             // number of vertices, read in main
int MX;            // size of the component currently searched by getroot
int R;             // best centroid candidate found by getroot (0 = none yet)
int sa[N];         // per-vertex weight; main rescales it to sa[i] / sum (mod MOD)
int ww[N];         // ww[d]: value multiplied by ans[d] in main's final sum
int siz[N];        // subtree sizes, recomputed by getroot and dfs
int ms[N];         // ms[u]: largest piece left if u were removed (set by getroot)
bool vis[N];       // vis[u]: u has already been used as a centroid
vector<int> V[N];  // adjacency lists

// Walk the unvisited component containing u (entered from fa), computing siz[]
// and ms[] for every vertex, and move R to the vertex with the smallest ms[],
// i.e. the one whose removal leaves the smallest largest piece. MX must hold
// the component size before the call.
void getroot(int u, int fa) {
    int largest = 0;  // biggest child subtree below u
    siz[u] = 1;
    for (int v : V[u]) {
        if (v != fa && !vis[v]) {
            getroot(v, u);
            siz[u] += siz[v];
            if (siz[v] > largest) largest = siz[v];
        }
    }
    // The part of the component "above" u also counts as a piece.
    int above = MX - siz[u];
    ms[u] = largest > above ? largest : above;
    if (ms[u] < ms[R]) R = u;
}

int dep[N];  // dep[x]: depth of x below the current centroid (set in divide/dfs)
int res[N];  // res[d]: sum of sa[] over the current subtree's vertices at depth d
int now[N];  // now[d]: sum of sa[] at depth d over the centroid and merged subtrees
int ss[N];   // scratch output buffer for NTT::solve
int md[N];   // md[x]: maximum depth inside x's subtree (set by dfs)
int ans[N];  // ans[d]: accumulated sum of sa[x] * sa[y] over pairs at distance d

// a = (a + b) mod MOD, assuming both operands are already in [0, MOD).
void upd(int &a, int b) {
    if ((a += b) >= MOD) a -= MOD;
}

// Walk u's subtree (not crossing fa or visited vertices), assigning dep[] to
// children as dep[parent] + 1 and computing siz[] and md[] (max depth) bottom-up.
// dep[u] must already be set by the caller.
void dfs(int u, int fa) {
    siz[u] = 1;
    md[u] = dep[u];
    for (int v : V[u]) {
        if (v == fa || vis[v]) continue;
        dep[v] = dep[u] + 1;
        dfs(v, u);
        siz[u] += siz[v];
        if (md[v] > md[u]) md[u] = md[v];
    }
}

// Add sa[x] into res[dep[x]] for every vertex x in u's subtree (not crossing fa
// or visited vertices). Uses the dep[] values computed by dfs.
void dfs1(int u, int fa) {
    int &bucket = res[dep[u]];
    upd(bucket, sa[u]);
    for (int v : V[u]) {
        if (v != fa && !vis[v]) dfs1(v, u);
    }
}

int id[N];  // id[1..tp]: unvisited neighbours of the centroid being processed
int tp;     // number of valid entries in id

// Sort predicate for divide(): orders subtree roots by increasing md[] (max depth).
bool cmp(int a, int b) {
    return md[b] > md[a];
}

// Process centroid u: add to ans[d] the sum of sa[x] * sa[y] over every unordered
// pair {x, y} at distance d >= 1 whose path passes through u, then recurse into
// each remaining component.
void divide(int u) {
    vis[u] = true;
    tp = 0;
    int mm = 0;  // highest depth that may be non-zero in now[]
    for (int v : V[u]) {
        if (vis[v]) continue;
        dep[v] = 1;
        dfs(v, u);  // fills dep[], md[], siz[] for v's subtree
        id[++tp] = v;
    }
    // Merge shallow subtrees first so every convolution has length O(md[v]).
    sort(id + 1, id + tp + 1, cmp);
    now[0] = sa[u];  // the centroid itself sits at depth 0
    for (int k = 1; k <= tp; k++) {
        int v = id[k];
        dfs1(v, u);  // res[d] = weight of this subtree at depth d
        int tt = NTT::solve(now, mm, res, md[v], ss);
        // Every path between the earlier vertices and this subtree passes through u,
        // so its length is the depth sum. Coefficients above mm + md[v] are zero.
        int deg = mm + md[v];
        for (int d = 1; d <= deg; d++) upd(ans[d], ss[d]);
        for (int d = 0; d <= md[v]; d++) {
            upd(now[d], res[d]);
            res[d] = 0;
        }
        for (int d = 0; d < tt; d++) ss[d] = 0;  // solve writes ss[0 .. tt-1]
        mm = max(mm, md[v]);
    }
    for (int d = 0; d <= mm; d++) now[d] = 0;
    // Recurse: siz[v] (from dfs above) is the size of the component hanging off u at v.
    for (int v : V[u]) {
        if (vis[v]) continue;
        R = 0; MX = siz[v];
        getroot(v, u);
        divide(R);
    }
}

int main() {
    //freopen("0.txt", "r", stdin);
    int a, b;
    scanf("%d", &n);
    // Total weight modulo MOD. A full '%' (instead of one conditional subtraction)
    // keeps the running sum in [0, MOD) even when a single weight is >= MOD.
    ll sum = 0;
    for (int i = 1; i <= n; i++) {
        scanf("%d", sa + i);
        sum = (sum + sa[i]) % MOD;
    }
    // Rescale each weight to sa[i] / sum (mod MOD) via the Fermat inverse of sum.
    ll RR = qpow(sum, MOD - 2);
    for (int i = 1; i <= n; i++) {
        sa[i] = RR * sa[i] % MOD;
        // divide() only accumulates distances >= 1; the pairs (x, x) are added here.
        ans[0] = (ans[0] + 1LL * sa[i] * sa[i]) % MOD;
    }
    for (int i = 0; i < n; i++) scanf("%d", ww + i);
    for (int i = 1; i < n; i++) {
        scanf("%d%d", &a, &b);
        V[a].push_back(b);
        V[b].push_back(a);
    }
    ms[0] = 1e9;  // sentinel: vertex 0 does not exist, so any real vertex beats R = 0
    MX = n;
    getroot(1, 0);
    divide(R);
    // divide() counts each unordered pair once; the factor 2 counts both orders.
    ll r = 1LL * ans[0] * ww[0] % MOD;
    for (int i = 1; i < n; i++) r = (r + 1LL * ans[i] * ww[i] * 2) % MOD;
    printf("%lld\n", r);
    return 0;
}