考虑删除一个结点之后 LCA 会怎么变化
只考虑 A 树,当且仅当存在 个结点在lca的一颗子树上时,删除剩下的那个结点,才会使 LCA 变化。如果存在这种情况,我们只需要找到删除掉之后能使 LCA 产生变化的结点并记录下来。
对于 B 树我们拥有相同的结论。
所以我们只需要至多特判删除两个结点,每次暴力跑 LCA (可以树上前缀和求 LCA ),然后判断在 A , B 两个树上的 LCA 的权值大小关系。
对于剩下的不会使 LCA 发生改变的结点,我们只需要判断一开始在两颗树上的 LCA 的大小关系就好了。
复杂度:O(n)
#include <bits/stdc++.h>
#define int long long
#define endl '\n'
#define lowbit(x) (x&(-x))
#define ull unsigned long long
using namespace std;
const string yes="YES\n",no="NO\n";
const int mod = 1000000007,N = 100005,inf=1e18;
const ull base=13331;
int n,res,ans,sum,cnt,k,x;
int a[200005],b[200005],c[200005];
int vis[200005],suma[200005],sumb[200005];
int qpow(int x,int y=mod-2,int mo=mod,int res=1){
for(;y;y>>=1,(x*=x)%=mo) if(y&1)(res*=x)%=mo;
return res;
}
int lcaa,lcab,posa,posb,fga,fgb;
vector<int>pa[200005],pb[200005];
string s[2000005];
void dfsa(int u){
for(auto v:pa[u]){
dfsa(v);
suma[u]+=suma[v];
}
}
void dfsb(int u){
for(auto v:pb[u]){
dfsb(v);
sumb[u]+=sumb[v];
}
}
void findlcaa(int u,int x){
if(suma[u]==x){
lcaa=u;
}
for(auto v:pa[u]){
findlcaa(v,x);
}
}
void findlcab(int u,int x){
if(sumb[u]==x){
lcab=u;
}
for(auto v:pb[u]){
findlcab(v,x);
}
}
void getposa(int u){
if(vis[u])posa=u;
for(auto v:pa[u]){
getposa(v);
}
}
void getposb(int u){
if(vis[u])posb=u;
for(auto v:pb[u]){
getposb(v);
}
}
void solve(){
cin>>n>>k;
for(int i=1;i<=k;i++){
cin>>x;
c[i]=x;
vis[x]=1;
suma[x]=sumb[x]=1;
}
for(int i=1;i<=n;i++){
cin>>a[i];
}
for(int i=2;i<=n;i++){
cin>>x;
pa[x].push_back(i);
}
for(int i=1;i<=n;i++){
cin>>b[i];
}
for(int i=2;i<=n;i++){
cin>>x;
pb[x].push_back(i);
}
if(k==2){
cout<<(a[c[1]]>b[c[1]])+(a[c[2]]>b[c[2]]);
return;
}
dfsa(1);
dfsb(1);
findlcaa(1,k);
findlcab(1,k);
//cout<<lcaa<<" "<<lcab<<endl;
for(auto v:pa[lcaa]){
if(suma[v]==k-1){
fga=1;
}
}
for(auto v:pb[lcab]){
if(sumb[v]==k-1){
fgb=1;
}
}
if(fga){
for(auto v:pa[lcaa]){
if(suma[v]==1){
getposa(v);
}
}
if(posa==0)posa=lcaa;
}
if(fgb){
for(auto v:pb[lcab]){
if(sumb[v]==1){
getposb(v);
}
}
if(posb==0)posb=lcab;
}
if(posa==posb){
fgb=0;
}
if(a[lcaa]>b[lcab]){
ans=k-fga-fgb;
}
if(fga){
for(int i=1;i<=n;i++){
suma[i]=sumb[i]=vis[i];
}
suma[posa]=sumb[posa]=0;
dfsa(1);
dfsb(1);
findlcaa(1,k-1);
findlcab(1,k-1);
if(a[lcaa]>b[lcab])ans++;
}
if(fgb){
for(int i=1;i<=n;i++){
suma[i]=sumb[i]=vis[i];
}
suma[posb]=sumb[posb]=0;
dfsa(1);
dfsb(1);
findlcaa(1,k-1);
findlcab(1,k-1);
if(a[lcaa]>b[lcab])ans++;
}
cout<<ans;
}
void main_init(){}
signed main(){
ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
cout<<fixed<<setprecision(12);
int t=1;
main_init();
//cin>>t;
while (t--)
solve();
}