先把每个点从根到它的和、整棵子树的和预处理出来,先算一遍不删边时的支撑点总数。然后再跑一遍非递归DFS,用树状数组维护祖先里哪些点会因为“删掉当前子树”而变成支撑点,最后每条边都按删前总数-被删子树里的支撑点+新变成的支撑点算一下,取最大值。

void solve(){
    int n;cin>>n;
    vll a(n+1);
    for(int i=1;i<=n;++i)cin>>a[i];
    vi p(n+1);
    for(int i=1;i<=n;++i)cin>>p[i];
    vvi g(n+1);
    for(int i=2;i<=n;++i)g[p[i]].push_back(i);

    vll pre(n+1),sub(n+1);
    for(int i=1;i<=n;++i){
        pre[i]=pre[p[i]]+a[i];
        sub[i]=a[i];
    }
    for(int i=n;i>=2;--i)sub[p[i]]+=sub[i];

    vb c1(n+1),c2(n+1),sup(n+1),sp(n+1);
    vll nd(n+1),vals;
    ll base=0;
    for(int i=1;i<=n;++i){
        c1[i]=(pre[i]>=2*a[i]);
        c2[i]=(sub[i]<=2*a[i]);
        sup[i]=(c1[i]&&c2[i]);
        base+=sup[i];
        if(c1[i]&&!c2[i]){
            sp[i]=1;
            nd[i]=sub[i]-2*a[i];
            vals.push_back(nd[i]);
        }
    }

    vll cnt(n+1);
    for(int i=1;i<=n;++i)cnt[i]=sup[i];
    for(int i=n;i>=2;--i)cnt[p[i]]+=cnt[i];

    sort(all(vals));
    vals.erase(unique(all(vals)),vals.end());
    int m=vals.size();
    vi bit(m+1);

    auto add=[&](int x,int v){
        for(int i=x;i<=m;i+=i&-i)bit[i]+=v;
    };
    auto ask=[&](int x){
        int s=0;
        for(int i=x;i>0;i-=i&-i)s+=bit[i];
        return s;
    };
    auto id=[&](ll x){
        return int(lower_bound(all(vals),x)-vals.begin())+1;
    };
    auto le=[&](ll x){
        return int(upper_bound(all(vals),x)-vals.begin());
    };

    vll gain(n+1);
    vi st,op;
    st.reserve(2*n+5);
    op.reserve(2*n+5);
    st.push_back(1);op.push_back(0);
    while(!st.empty()){
        int u=st.back();st.pop_back();
        int t=op.back();op.pop_back();
        if(t==0){
            gain[u]=ask(le(sub[u]));
            if(sp[u])add(id(nd[u]),1);
            st.push_back(u);
            op.push_back(1);
            for(int i=(int)g[u].size()-1;i>=0;--i){
                int v=g[u][i];
                st.push_back(v);
                op.push_back(0);
            }
        }else{
            if(sp[u])add(id(nd[u]),-1);
        }
    }

    ll ans=base;
    for(int v=2;v<=n;++v){
        ll cur=base-cnt[v]+gain[v];
        if(cur>ans)ans=cur;
    }
    cout<<ans<<endl;
}