贴个代码吧

#include <bits/stdc++.h>

// Presumably shorthands for std::pair members. NOTE(review): these are object-like macros,
// so every later identifier token `x` / `y` in the file (locals and parameters included)
// is rewritten to `first` / `second`.
#define x first
#define y second
// Expands to the begin/end iterator pair of its argument. The argument is not
// parenthesised and appears twice in the expansion.
#define all(x) x.begin(), x.end()
// vec1..vec4 declare a variable `name` holding a 1- to 4-level nested vector with the
// given extents, every innermost element initialised to `val`.
// Only vec4 parenthesises its arguments.
#define vec1(T, name, n, val) vector<T> name(n, val)
#define vec2(T, name, n, m, val) vector<vector<T>> name(n, vector<T>(m, val))
#define vec3(T, name, n, m, k, val) vector<vector<vector<T>>> name(n, vector<vector<T>>(m, vector<T>(k, val)))
#define vec4(T, name, n, m, k, p, val) vector<vector<vector<vector<T>>>> name((n), vector<vector<vector<T>>>((m), vector<vector<T>>((k), vector<T>((p), (val)))))

// Makes every std:: name visible unqualified for the rest of the file.
using namespace std;
// 128-bit integers (compiler extension types).
using i128 = __int128;
using u128 = unsigned __int128;
// Short names for the wide builtin arithmetic types.
using LL = long long;
using LD = long double;
using ULL = unsigned long long;
// Pairs of int, long long and long double.
using PII = pair<int, int>;
using PLL = pair<LL, LL>;
using PLD = pair<LD, LD>;

// N = 100010 (not referenced in the visible code); MOD is the modulus used by qpow().
const int N = 1e5 + 10, MOD = 998244353;
// Large values, presumably used as "infinity" sentinels for int and long long.
const int INF = 1e9;
const LL LL_INF = 2e18;
// Tolerance used by cmp() when comparing long doubles.
const LD EPS = 1e-8;
// Coordinate offset tables: the 4 axis-aligned unit steps, the 8 surrounding cells, and
// the 8 jumps of the form (+-1, +-2) / (+-2, +-1), i.e. knight moves.
const int dx4[] = {-1, 0, 1, 0}, dy4[] = {0, 1, 0, -1};
const int dx8[] = {-1, -1, -1, 0, 0, 1, 1, 1}, dy8[] = {-1, 0, 1, -1, 1, -1, 0, 1};
const int hx[] = {-2, -2, -1, -1, 1, 1, 2, 2}, hy[] = {-1, 1, -2, 2, -2, 2, -1, 1};

// Extracts one whitespace-delimited decimal token into a signed 128-bit integer.
//
// Accepted form: an optional '+' or '-' followed by one or more decimal digits.
// On any failure (nothing could be read, a lone sign, or a non-digit character) `val` is
// set to 0 and the stream's failbit is set, mirroring the builtin integer extractors.
// Magnitudes that do not fit in 128 bits are not detected; they wrap.
std::istream& operator>>(std::istream& is, __int128& val) {
    val = 0;

    std::string token;
    if (!(is >> token)) return is;  // the string extractor has already set failbit

    // A successful string extraction always yields at least one character.
    std::size_t pos = 0;
    bool negative = false;
    if (token[0] == '-' || token[0] == '+') {
        negative = (token[0] == '-');
        pos = 1;
    }
    if (pos == token.size()) {  // a sign with no digits after it
        is.setstate(std::ios_base::failbit);
        return is;
    }

    // Accumulate the magnitude unsigned so that the most negative value (whose magnitude
    // is not representable as a signed 128-bit integer) parses without signed overflow.
    unsigned __int128 magnitude = 0;
    for (; pos < token.size(); ++pos) {
        const char c = token[pos];
        if (c < '0' || c > '9') {
            is.setstate(std::ios_base::failbit);
            return is;
        }
        magnitude = magnitude * 10 + static_cast<unsigned>(c - '0');
    }

    val = static_cast<__int128>(negative ? 0 - magnitude : magnitude);
    return is;
}

// Writes a signed 128-bit integer in decimal.
//
// The digits are produced from the unsigned magnitude, so the most negative value (which
// cannot be negated in the signed type) prints correctly. The text is inserted in a
// single write, so a field width set on the stream applies to the whole number.
std::ostream& operator<<(std::ostream& os, __int128 val) {
    const bool negative = val < 0;
    unsigned __int128 magnitude = static_cast<unsigned __int128>(val);
    if (negative) magnitude = 0 - magnitude;  // well-defined unsigned negation

    // 2^127 has 39 decimal digits; one more slot for the sign.
    constexpr int kCapacity = 40;
    char text[kCapacity];
    int start = kCapacity;
    do {
        text[--start] = static_cast<char>('0' + static_cast<int>(magnitude % 10));
        magnitude /= 10;
    } while (magnitude != 0);
    if (negative) text[--start] = '-';

    return os << std::string(text + start, text + kCapacity);
}

// Approximate equality: true when a and b differ by strictly less than EPS.
bool cmp(LD a, LD b) {
    const LD gap = fabs(a - b);
    return gap < EPS;
}

// Modular exponentiation by repeated squaring: returns a^b mod MOD in [0, MOD).
//
// a may be any long long; it is first reduced into [0, MOD), because `%` keeps the sign
// of a negative dividend and would otherwise produce a negative result.
// b is expected to be non-negative. A negative exponent has no meaning here; the loop is
// simply skipped and 1 is returned rather than shifting a negative value forever.
LL qpow(LL a, LL b) {
    LL base = a % MOD;
    if (base < 0) base += MOD;

    LL ans = 1;
    while (b > 0) {
        if (b & 1) ans = ans * base % MOD;
        base = base * base % MOD;  // base < MOD < 2^30, so the product fits in long long
        b >>= 1;
    }
    return ans;
}

// Fenwick tree (binary indexed tree) of long long sums over the indices 1.._n.
// Offers point update, prefix / range sum, and a descent that locates the first index
// whose prefix sum reaches a target.
struct Fenwick {
    int n;                      // number of slots in tr; valid indices are 1..n-1
    std::vector<long long> tr;  // tree nodes, slot 0 is unused

    // Creates an all-zero tree that can hold the indices 1.._n.
    Fenwick(int _n) : n(_n + 1), tr(_n + 1, 0) {
    }

    // Value of the lowest set bit of v.
    int lowbit(int v) {
        return v & -v;
    }

    // Adds delta to the element at index u. Indices outside 1..n-1 are ignored.
    void add(int u, long long delta) {
        if (u <= 0) return;
        for (int i = u; i < n; i += lowbit(i)) tr[i] += delta;
    }

    // Sum of the elements at indices 1..u. u is clamped to the last valid index, and any
    // u <= 0 yields 0.
    long long query(int u) {
        u = std::min(u, n - 1);
        long long ans = 0;
        // `i > 0` rather than `i != 0`: a negative u must not walk to indices before the
        // start of the buffer (query(a, b) with a <= 0 reaches this with a negative u).
        for (int i = u; i > 0; i -= lowbit(i)) ans += tr[i];
        return ans;
    }

    // Sum of the elements at indices a..b, or 0 when the range is empty (a > b).
    long long query(int a, int b) {
        if (a > b) return 0;
        return query(b) - query(a - 1);
    }

    // Returns the smallest index whose prefix sum is >= k, assuming every stored element
    // is non-negative. Returns n (one past the last valid index) when the total is < k.
    long long kth(long long k) {
        // Start from the largest power of two that does not exceed the last valid index,
        // so the descent works for any tree size.
        int step = 1;
        while (step <= (n - 1) / 2) step <<= 1;

        int pos = 0;
        for (; step; step >>= 1) {
            // Strictly `< n`: tr has n slots, so index n itself is out of range.
            if (pos + step < n && tr[pos + step] < k) {
                k -= tr[pos + step];
                pos += step;
            }
        }
        return pos + 1;
    }
};

// Coordinate compression helper: collect values with add(), then map any value to a
// 1-based rank among the distinct collected values.
// The table is finalised (sorted and deduplicated) by init(), which every lookup runs on
// demand. Once finalised, further add() calls are ignored.
struct Disc {
    bool flg;      // true once the table has been finalised
    vector<LL> a;  // collected values; sorted and unique after init()

    Disc() : flg(false) {
    }

    // Records v while the table is still open; does nothing after finalisation.
    void add(LL v) {
        if (!flg) a.emplace_back(v);
    }

    // Sorts and deduplicates the collected values. Only the first call has any effect.
    void init() {
        if (!flg) {
            flg = true;
            sort(a.begin(), a.end());
            const auto tail = unique(a.begin(), a.end());
            a.erase(tail, a.end());
        }
    }

    // 1-based position of the first stored value that is >= v
    // (a.size() + 1 when every stored value is smaller).
    int get(LL v) {
        init();
        const auto it = lower_bound(a.begin(), a.end(), v);
        return it - a.begin() + 1;
    }

    // 1-based position of the first stored value that is > v
    // (a.size() + 1 when there is none).
    int upget(LL v) {
        init();
        const auto it = upper_bound(a.begin(), a.end(), v);
        return it - a.begin() + 1;
    }

    // Number of distinct values stored.
    int sz() {
        init();
        return a.size();
    }
};

void solve() {
    int n;
    cin >> n;
    vector<int> a(n + 1);
    for (int i = 1; i <= n; ++i) cin >> a[i];
    vector<vector<int>> g(n + 1);
    for (int i = 1; i <= n; ++i){
        int fa;
        cin >> fa;
        if (i == 1) continue;
        g[fa].push_back(i);
        g[i].push_back(fa);
    }

    LL sum = 0;
    Disc d;
    vector<int> is(n + 1);
    vector<LL> f(n + 1);
    vector<int> pcnt(n + 1);
    auto dfs1 = [&](auto dfs, int u, int fa) -> void {
        sum += a[u];
        f[u] = a[u];
        for (int v : g[u]) {
            if (v == fa) continue;
            dfs(dfs, v, u);
            f[u] += f[v];
            pcnt[u] += pcnt[v];
        }

        if (sum - a[u] >= a[u] && f[u] - a[u] <= a[u]) {
            is[u] = 1;
            pcnt[u]++;
        }
        d.add(f[u] - 2 * a[u]);
        sum -= a[u];
    };

    dfs1(dfs1, 1, 0);

    Fenwick tr(d.sz() + 1);
    LL tot = pcnt[1];
    LL ans = tot;
    sum = 0;
    auto dfs2 = [&](auto dfs, int u, int fa) -> void {
        sum += a[u];
        if (sum - a[u] >= a[u] && !is[u]) {
            tr.add(d.get(f[u] - 2 * a[u]), 1);
        }
        for (int v : g[u]) {
            if (v == fa) continue;
            LL cur = tot - pcnt[v] + tr.query(d.upget(f[v]) - 1);
            ans = max(ans, cur);
            dfs(dfs, v, u);
        }
        if (sum - a[u] >= a[u] && !is[u]) {
            tr.add(d.get(f[u] - 2 * a[u]), -1);
        }
        sum -= a[u];
    };

    dfs2(dfs2, 1, 0);

    cout << ans << '\n';

/**/ #ifdef LOCAL
    cout << flush;
/**/ #endif
}

// Program entry point: configures the standard streams, then runs solve() T times.
int main() {
    // Decouple iostream from C stdio and stop cin from flushing cout before every read.
    ios::sync_with_stdio(false);
    cin.tie(nullptr);
    // Floating-point formatting must be set before any output is produced; it has no
    // effect on the integer output printed by solve().
    cout << fixed << setprecision(15);

    int T = 1;  // single test case
    while (T--) solve();

    return 0;
}