考虑删除一个结点之后 LCA 会怎么变化

只考虑 A 树,当且仅当存在 k1k-1 个结点在lca的一颗子树上时,删除剩下的那个结点,才会使 LCA 变化。如果存在这种情况,我们只需要找到删除掉之后能使 LCA 产生变化的结点并记录下来。

对于 B 树我们拥有相同的结论。

所以我们只需要至多特判删除两个结点,每次暴力跑 LCA (可以树上前缀和求 LCA ),然后判断在 A , B 两个树上的 LCA 的权值大小关系。

对于剩下的不会使 LCA 发生改变的结点,我们只需要判断一开始在两颗树上的 LCA 的大小关系就好了。

复杂度:O(n)

#include <bits/stdc++.h>
#define int long long
#define endl '\n'
#define lowbit(x) (x&(-x))
#define ull unsigned long long
using namespace std;
const string yes="YES\n",no="NO\n";
const int mod = 1000000007,N = 100005,inf=1e18;
const ull base=13331;
int n,res,ans,sum,cnt,k,x;
int a[200005],b[200005],c[200005];
int vis[200005],suma[200005],sumb[200005];
int qpow(int x,int y=mod-2,int mo=mod,int res=1){
    for(;y;y>>=1,(x*=x)%=mo) if(y&1)(res*=x)%=mo;
    return res;
}
int lcaa,lcab,posa,posb,fga,fgb;
vector<int>pa[200005],pb[200005];
string s[2000005];
void dfsa(int u){
    for(auto v:pa[u]){
        dfsa(v);
        suma[u]+=suma[v];
    }
}
void dfsb(int u){
    for(auto v:pb[u]){
        dfsb(v);
        sumb[u]+=sumb[v];
    }
}
void findlcaa(int u,int x){
    if(suma[u]==x){
        lcaa=u;
    }
    for(auto v:pa[u]){
        findlcaa(v,x);
    }
}
void findlcab(int u,int x){
    if(sumb[u]==x){
        lcab=u;
    }
    for(auto v:pb[u]){
        findlcab(v,x);
    }
}
void getposa(int u){
    if(vis[u])posa=u;
    for(auto v:pa[u]){
        getposa(v);
    }
}
void getposb(int u){
    if(vis[u])posb=u;
    for(auto v:pb[u]){
        getposb(v);
    }
}

void solve(){
    cin>>n>>k;
    for(int i=1;i<=k;i++){
        cin>>x;
        c[i]=x;
        vis[x]=1;
        suma[x]=sumb[x]=1;
    }
    for(int i=1;i<=n;i++){
        cin>>a[i];
    }
    for(int i=2;i<=n;i++){
        cin>>x;
        pa[x].push_back(i);
    }
    for(int i=1;i<=n;i++){
        cin>>b[i];
    }
    for(int i=2;i<=n;i++){
        cin>>x;
        pb[x].push_back(i);
    }
    if(k==2){
        cout<<(a[c[1]]>b[c[1]])+(a[c[2]]>b[c[2]]);
        return;
    }
    dfsa(1);
    dfsb(1);
    findlcaa(1,k);
    findlcab(1,k);
    //cout<<lcaa<<" "<<lcab<<endl;
    for(auto v:pa[lcaa]){
        if(suma[v]==k-1){
            fga=1;
        }
    }
    for(auto v:pb[lcab]){
        if(sumb[v]==k-1){
            fgb=1;
        }
    }
    if(fga){
        for(auto v:pa[lcaa]){
            if(suma[v]==1){
                getposa(v);
            }
        }
        if(posa==0)posa=lcaa;
    }
    if(fgb){
        for(auto v:pb[lcab]){
            if(sumb[v]==1){
                getposb(v);
            }
        }
        if(posb==0)posb=lcab;
    }
    if(posa==posb){
        fgb=0;
    }
    if(a[lcaa]>b[lcab]){
        ans=k-fga-fgb;
    }
    if(fga){
        for(int i=1;i<=n;i++){
            suma[i]=sumb[i]=vis[i];
        }
        suma[posa]=sumb[posa]=0;
        dfsa(1);
        dfsb(1);
        findlcaa(1,k-1);
        findlcab(1,k-1);
        if(a[lcaa]>b[lcab])ans++;
    }
    if(fgb){
        for(int i=1;i<=n;i++){
            suma[i]=sumb[i]=vis[i];
        }
        suma[posb]=sumb[posb]=0;
        dfsa(1);
        dfsb(1);
        findlcaa(1,k-1);
        findlcab(1,k-1);
        if(a[lcaa]>b[lcab])ans++;
    }
    cout<<ans;

}
void main_init(){}
signed main(){
    ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
	cout<<fixed<<setprecision(12);
    int t=1;
	main_init();
    //cin>>t;
    while (t--)
        solve();
    
		
}