好路径的条件是“既要有小于等于 a 的点,又要有大于等于 b 的点”,那就把不好的先删掉:要么整条路都 >a,要么整条路都 <b。
所以答案=总路径数-两种坏路径,再把重复减掉的“整条路都在 (a,b) 里”加回去;而“某种条件下的路径数”在树上就是把符合条件的点分连通块,每块贡献 sz*(sz+1)/2。
void solve(){
int n,a,b;cin>>n>>a>>b;
vll w(n+1);
for(int i=1;i<=n;++i)cin>>w[i];
vvi g(n+1);
for(int i=1;i<n;++i){
int u,v;cin>>u>>v;
g[u].push_back(v);
g[v].push_back(u);
}
auto cal=[&](int t)->ll{
vb vis(n+1,0);
ll res=0;
vi st;
for(int i=1;i<=n;++i){
bool ok=0;
if(t==0)ok=(w[i]>a);
else if(t==1)ok=(w[i]<b);
else ok=(w[i]>a&&w[i]<b);
if(ok==0||vis[i])continue;
ll c=0;
st.clear();
st.push_back(i);
vis[i]=1;
while(!st.empty()){
int u=st.back();
st.pop_back();
++c;
for(int j=0;j<(int)g[u].size();++j){
int v=g[u][j];
bool ok2=0;
if(t==0)ok2=(w[v]>a);
else if(t==1)ok2=(w[v]<b);
else ok2=(w[v]>a&&w[v]<b);
if(ok2==0||vis[v])continue;
vis[v]=1;
st.push_back(v);
}
}
res+=c*(c+1)/2;
}
return res;
};
ll sum=1LL*n*(n+1)/2;
ll x=cal(0),y=cal(1),z=cal(2);
cout<<sum-x-y+z<<endl;
}

京公网安备 11010502036488号