G题最简单做法题解。看没人写dsu/线段树合并的题解来写一个,出题人只提了一嘴这个做法没有细说,这感觉也算经典线段树合并dp/计数的模式了。看出题人建立虚树的做法思路是通过三个点中的
来互相计数的,这里我们通过找
来计数其实更简单,问题转化为对于一个点
,考虑有几对相同
的
对路径经过它,且
。我用的线段树合并来解决这个问题,具体来说,合并到某个点
的时候,对于子树内相互匹配的点,我们可以边合并边计数,我们在
的时候先记一下
,若在线段树叶子节点发生合并,且合并位置的
,则此时会发生贡献,为
(假设合并的节点为
,且这种
分别有
个)。对于子树外和子树内匹配的点,我们在
子树合并完后统计,具体而言,我们这时候其实可以知道每种颜色在子树内有多少个,由于某种颜色全局数量是确定的,所以可以计算出子树外有多少个,则匹配数就是
(假设
全局有
个,当前有
个)。我们可以动态维护这个东西,然后合并完
子树后对这个东西在
求个区间和即可。
代码如下,时间复杂度。(不到100行,非常简单)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1e5+5;
int n,m,w[maxn],b[maxn],ans,totw[maxn];
vector<int> G[maxn];
struct Segtree
{
int lc[maxn*40],rc[maxn*40],tot;
int sum[maxn*40],val[maxn*40];
void modify(int u,int l,int r,int x)
{
if(l==r)
{
val[u]++;
sum[u]=val[u]*(totw[l]-val[u]);
// cout<<l<<' '<<u<<' '<<val[u]<<' '<<sum[u]<<endl;
return;
}
int mid=(l+r)>>1;
if(x<=mid) modify(lc[u]=lc[u]?lc[u]:++tot,l,mid,x);
else modify(rc[u]=rc[u]?rc[u]:++tot,mid+1,r,x);
sum[u]=sum[lc[u]]+sum[rc[u]];
}
int query(int u,int l,int r,int x,int y)
{
if(!u) return 0;
if(x<=l&&r<=y) return sum[u];
int mid=(l+r)>>1,res=0;
if(x<=mid) res+=query(lc[u],l,mid,x,y);
if(y>mid) res+=query(rc[u],mid+1,r,x,y);
return res;
}
void merge(int u,int v,int l,int r,int noww)
{
if(l==r)
{
if(l<noww)
{
ans+=val[u]*val[v];
}
val[u]+=val[v];
sum[u]=val[u]*(totw[l]-val[u]);
return;
}
int mid=(l+r)>>1;
if(lc[u]&&lc[v]) merge(lc[u],lc[v],l,mid,noww);
else if(lc[v]) lc[u]=lc[v];
if(rc[u]&&rc[v]) merge(rc[u],rc[v],mid+1,r,noww);
else if(rc[v]) rc[u]=rc[v];
sum[u]=sum[lc[u]]+sum[rc[u]];
}
}wife;
void dfs(int u,int f)
{
for(auto v:G[u])
if(v!=f)
dfs(v,u);
wife.modify(u,1,m,w[u]);
for(auto v:G[u])
if(v!=f)
wife.merge(u,v,1,m,w[u]);
if(w[u]>1) ans+=wife.query(u,1,m,1,w[u]-1);
// cout<<u<<' '<<ans<<endl;
}
signed main()
{
ios::sync_with_stdio(false);
cin.tie(0);cout.tie(0);
cin>>n;
wife.tot=n;
for(int i=1;i<=n;++i)
cin>>w[i],b[i]=w[i];
sort(b+1,b+n+1);
m=unique(b+1,b+n+1)-b-1;
for(int i=1;i<=n;++i)
w[i]=lower_bound(b+1,b+m+1,w[i])-b,++totw[w[i]];
for(int i=1;i<n;++i)
{
int u,v;
cin>>u>>v;
G[u].push_back(v);
G[v].push_back(u);
}
dfs(1,0);
cout<<ans<<endl;
return 0;
}