题目链接:http://acm.split.hdu.edu.cn/showproblem.php?pid=5977

题意:给一个有n(1<=n<=5e4)个节点的树,每个节点有一个颜色,共有k(1<=k<=7)种不同的颜色,问有多少个点对(u,v)(注意(u,v)和(v,u)算作两个答案),这些点对之间的路径上有k个不同的颜色。

解法:树分治+高维前缀和,暴力枚举就超时了,但是只要知道高维前缀和以及枚举子集的相关知识,就能用O(k*(1<<k))结束统计,外面套上树分治,O(n*logn+n*k*(1<<k) )的时间。树分治还是那样的树分治,用状态压缩记录路径上遇到的点,接着用高维前缀和统计每个状态和他们的超集的个数,一次分治的计数方法就是num[i]*cnt[i^(1<<k-1)],num[i]记录的是状态i的出现次数,cnt[i]是状态i和i超集的总个数,统计每个i。高维前缀和是一种计数方法,只是短短的几行,用来计算包含i的集合的总个数。其他部分应该都ok。


//高维前缀和+树分治
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5+10;
const int maxm = 2e5+10;
struct edge{
    int v,next,w;
}E[maxm];
int head[maxn],edgecnt;
int n,vis[maxn],root;
LL ans;
void init(){
    memset(head,-1,sizeof(head));
    memset(vis,0,sizeof(vis));
    edgecnt=ans=0;
}
void add(int u,int v){
    E[edgecnt].v = v, E[edgecnt].next = head[u],head[u] = edgecnt++;
}
int mx[maxn],siz[maxn],col[maxn],mi,K,k;
LL cnt[maxn], cnt1[maxn];
void dfssize(int u, int fa){//处理子树的大小
    siz[u]=1;
    mx[u]=0;
    for(int i=head[u];~i;i=E[i].next){
        int v=E[i].v;
        if(v!=fa&&!vis[v]){
            dfssize(v, u);
            siz[u]+=siz[v];
            if(siz[v]>mx[u]) mx[u]=siz[v];
        }
    }
}
void dfsroot(int r, int u, int fa){//求重心
    if(siz[r]-siz[u]>mx[u]) mx[u]=siz[r]-siz[u];
    if(mx[u]<mi) mi=mx[u],root=u;
    for(int i=head[u]; ~i; i=E[i].next){
        int v=E[i].v;
        if(v!=fa&&!vis[v]) dfsroot(r, v, u);
    }
}
void dfs(int u, int fa, int sta){
    sta |= (1<<col[u]);
    ++cnt[sta];
    ++cnt1[sta];
    for(int i=head[u]; ~i; i=E[i].next){
        int v = E[i].v;
        if(!vis[v]&&v!=fa) dfs(v, u, sta);
    }
}
LL cal(int u, int sta){
    for(int i=0; i<=K; i++) cnt[i]=cnt1[i]=0;
    sta |= (1<<col[u]);
    ++cnt[sta];
    ++cnt1[sta];
    for(int i=head[u]; ~i; i=E[i].next){
        int v = E[i].v;
        if(!vis[v]) dfs(v, u, sta);
    }
    for(int i=0; i<k; i++)
        for(int j=0; j<K; j++)
            if(!(j&(1<<i))) cnt[j] += cnt[j|(1<<i)];
    LL ret = 0;
    for(int i=1; i<=K; i++) ret += cnt1[i]*cnt[(K^i)];
    return ret;
}
void DFS(int u){
    mi = n;
    dfssize(u, 0);
    dfsroot(u, u, 0);
    u = root;
    ans += cal(root, 0);
    vis[root] = 1;
    for(int i=head[root]; ~i; i=E[i].next){
        int v = E[i].v;
        if(!vis[v]){
            ans -= cal(v, (1<<col[u]));
            DFS(v);
        }
    }
}
int main()
{
    while(~scanf("%d%d", &n,&k)){
        K = (1<<k)-1;
        init();
        for(int i=1; i<=n; i++) scanf("%d", &col[i]), col[i]--;
        for(int i=1; i<n; i++){
            int u, v;
            scanf("%d %d", &u,&v);
            add(u, v);
            add(v, u);
        }
        DFS(1);
        printf("%lld\n", ans);
    }
    return 0;
}