题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=5977
题意:给一个有n(1<=n<=5e4)个节点的树,每个节点有一个颜色,共有k(1<=k<=7)种不同的颜色,问有多少个点对(u,v)(注意(u,v)和(v,u)算作两个答案),这些点对之间的路径上有k个不同的颜色。
解法:树分治+高维前缀和,暴力枚举就超时了,但是只要知道高维前缀和以及枚举子集的相关知识,就能用O(k*(1<<k))结束统计,外面套上树分治,O(n*logn+n*k*(1<<k) )的时间。树分治还是那样的树分治,用状态压缩记录路径上遇到的点,接着用高维前缀和统计每个状态和他们的超集的个数,一次分治的计数方法就是num[i]*cnt[i^(1<<k-1)],num[i]记录的是状态i的出现次数,cnt[i]是状态i和i超集的总个数,统计每个i。高维前缀和是一种计数方法,只是短短的几行,用来计算包含i的集合的总个数。其他部分应该都ok。
//高维前缀和+树分治
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5+10;
const int maxm = 2e5+10;
struct edge{
int v,next,w;
}E[maxm];
int head[maxn],edgecnt;
int n,vis[maxn],root;
LL ans;
void init(){
memset(head,-1,sizeof(head));
memset(vis,0,sizeof(vis));
edgecnt=ans=0;
}
void add(int u,int v){
E[edgecnt].v = v, E[edgecnt].next = head[u],head[u] = edgecnt++;
}
int mx[maxn],siz[maxn],col[maxn],mi,K,k;
LL cnt[maxn], cnt1[maxn];
void dfssize(int u, int fa){//处理子树的大小
siz[u]=1;
mx[u]=0;
for(int i=head[u];~i;i=E[i].next){
int v=E[i].v;
if(v!=fa&&!vis[v]){
dfssize(v, u);
siz[u]+=siz[v];
if(siz[v]>mx[u]) mx[u]=siz[v];
}
}
}
void dfsroot(int r, int u, int fa){//求重心
if(siz[r]-siz[u]>mx[u]) mx[u]=siz[r]-siz[u];
if(mx[u]<mi) mi=mx[u],root=u;
for(int i=head[u]; ~i; i=E[i].next){
int v=E[i].v;
if(v!=fa&&!vis[v]) dfsroot(r, v, u);
}
}
void dfs(int u, int fa, int sta){
sta |= (1<<col[u]);
++cnt[sta];
++cnt1[sta];
for(int i=head[u]; ~i; i=E[i].next){
int v = E[i].v;
if(!vis[v]&&v!=fa) dfs(v, u, sta);
}
}
LL cal(int u, int sta){
for(int i=0; i<=K; i++) cnt[i]=cnt1[i]=0;
sta |= (1<<col[u]);
++cnt[sta];
++cnt1[sta];
for(int i=head[u]; ~i; i=E[i].next){
int v = E[i].v;
if(!vis[v]) dfs(v, u, sta);
}
for(int i=0; i<k; i++)
for(int j=0; j<K; j++)
if(!(j&(1<<i))) cnt[j] += cnt[j|(1<<i)];
LL ret = 0;
for(int i=1; i<=K; i++) ret += cnt1[i]*cnt[(K^i)];
return ret;
}
void DFS(int u){
mi = n;
dfssize(u, 0);
dfsroot(u, u, 0);
u = root;
ans += cal(root, 0);
vis[root] = 1;
for(int i=head[root]; ~i; i=E[i].next){
int v = E[i].v;
if(!vis[v]){
ans -= cal(v, (1<<col[u]));
DFS(v);
}
}
}
int main()
{
while(~scanf("%d%d", &n,&k)){
K = (1<<k)-1;
init();
for(int i=1; i<=n; i++) scanf("%d", &col[i]), col[i]--;
for(int i=1; i<n; i++){
int u, v;
scanf("%d %d", &u,&v);
add(u, v);
add(v, u);
}
DFS(1);
printf("%lld\n", ans);
}
return 0;
}