ACM模版

描述

题解

其实这就是一道树归题而已,比赛时就知道,但是时间不够写了……给一下官方题解吧~~~

单独考虑每一种颜色,答案就是对于每种颜色至少经过一次这种的路径条数之和。
反过来思考只需要求有多少条路径没有经过这种颜色即可。
直接做可以采用虚树的思想(不用真正建出来),
对每种颜色的点按照 dfs 序列排个序,
就能求出这些点把原来的树划分成的块的大小。
这个过程实际上可以直接一次 dfs 求出。

这个题解对于树归学得一般的人来说,可能并不容易理解,其实这个题的核心就是虚树以及如何划分块儿,看下图:

对于这个子树,他的结点个数是 13 ,而它的子孙后代中,以红色结点为根的子树的结点个数之和为 11 ,那么,我们可以确定的是,只剩下两个结点划分为一块儿,这个块儿内的所有点之间的路径均为经过红色结点,那么你可能会问,这个连通块儿不会向上延伸吗?所以,这里有一个限制,如图中的虚根,是针对于整个树而言的,而对于每个子树求联通块儿,必须子树的根结点的父亲的颜色是红色才行,只有这样才能求到一个封闭的联通块儿,所以这里的 rt 对于整个树而言,是每一种颜色的结点都要算一下,而对于子树而言, rt 的颜色实际上就是子树的根的父亲的颜色。所以,不管是全局还是局部,这个规律都具有一般性,只不过全局来看,根本无法向上延伸了,所以就要考虑所有的情况罢了!

对了,我这里画的图还是意图的不够彻底,不要误以为连通块儿内只能是一种颜色,你大可以将剩下的两个结点的颜色随便改,除了红色外。一开始画图时我是改成了一个橙色一个绿色,后来可能多撤销了一次,结果还是原来的两个橙色,当然,这已经是废话了……

代码

#include <iostream>
#include <algorithm>
#include <cstdio>
#include <cstring>

typedef long long ll;

const int MAXN = 2e5 + 10;

int n;
ll ans;
int c[MAXN];
int lnk[MAXN];
int pos[MAXN];
int ctr[MAXN];  // 某种颜色的虚树结点个数
int rem[MAXN];  // 某种颜色的所有虚树的结点个数

struct Edge
{
    int nxt, v;
} e[MAXN << 1];

inline ll get_cnt(int x)
{
    return (x * (x - 1LL)) >> 1;
}

int dfs(int rt, int pre)
{
    int su = 1, o = pos[c[rt]];
    pos[c[rt]] = rt;
    for (int it = lnk[rt]; it; it = e[it].nxt)
    {
        if (e[it].v != pre)
        {
            ctr[rt] = 0;
            int sv = dfs(e[it].v, rt);
            ans -= get_cnt(sv - ctr[rt]);
            su += sv;
        }
    }
    (o ? ctr[o] : rem[c[rt]]) += su;
    pos[c[rt]] = o;

    return su;  // 返回子树的结点个数
}

template <class T>
inline void scan_d(T &ret)
{
    char c;
    ret = 0;
    while ((c = getchar()) < '0' || c > '9');
    while (c >= '0' && c <= '9')
    {
        ret = ret * 10 + (c - '0'), c = getchar();
    }
}

int main()
{
    int ce = 1;
    while (~scanf("%d", &n))
    {
        memset(lnk, 0, sizeof(lnk));
        memset(rem, 0, sizeof(rem));

        for (int i = 1; i <= n; ++i)
        {
            scan_d(c[i]);
        }

        int u, v;
        for (int i = 1; i < n; ++i)
        {
            scan_d(u), scan_d(v);
            e[i << 1] = (Edge){lnk[u], v};
            lnk[u] = i << 1;
            e[i << 1 | 1] = (Edge){lnk[v], u};
            lnk[v] = i << 1 | 1;
        }

        ans = get_cnt(n) * n;

        dfs(1, -1);

        for (int i = 1; i <= n; ++i)
        {
            ans -= get_cnt(n - rem[i]);
        }

        printf("Case #%d: %lld\n", ce++, ans);
    }

    return 0;
}