Problem Description
There is a tree with n nodes, each of which has a type of color represented by an integer, where the color of node i is ci.
The path between each two different nodes is unique, of which we define the value as the number of different colors appearing in it.
Calculate the sum of values of all paths on the tree that has n(n−1)2 paths in total.
Input
The input contains multiple test cases.
For each test case, the first line contains one positive integers n, indicating the number of node. (2≤n≤200000)
Next line contains n integers where the i-th integer represents ci, the color of node i. (1≤ci≤n)
Each of the next n−1 lines contains two positive integers x,y (1≤x,y≤n,x≠y), meaning an edge between node x and node y.
It is guaranteed that these edges form a tree.
Output
For each test case, output “Case #x: y” in one line (without quotes), where x indicates the case number starting from 1 and y denotes the answer of corresponding case.
Sample Input
3
1 2 1
1 2
2 3
6
1 2 1 3 2 1
1 2
1 3
2 4
2 5
3 6
Sample Output
Case #1: 6
Case #2: 29
解法:首先来考虑一种颜色怎么计算?
我用画图板画了一颗树,不要吐槽。然后现在1,5,6,9是红色的。对于红色我们如何计算呢?如果直接计算会发现非常麻烦,基本是算不到的。所以我们可以反着来计算,计算不经过红色的点的路径条数,然后用总路径数n*(n-1)/2减掉就是答案了,可以发现图中的话就是n*(n-1)/2减掉我用椭圆圈卡来的联通块的方案数的和,也就是3*(3-1)/2+2*(2-1)/2。那么在对于这个图就计算出来红色的了。然后对于每一种颜色单独计算答案,然后求个和就可以算出最后的答案了。也就是假设有cnt个颜色,答案就是cnt*(n*(n-1)/2)-sum(不贡献答案的值)。前面可以直接算,难点是后面怎么算,这也是本题最关键的问题。我们发现对于途中红色的点1,我们如何计算1和5,6之间的这个椭圆包含的联通快的大小呢?
我们想想,要是记录下红色在DFS过程中是怎么加进我们的递归栈中的,那么对于当前红色节点,对于他的子树如果把以他为DFS回溯过程中作为第一个碰到的点的话,那么这些子节点(必须是红色)的size减掉就是当前红色节点和5,6之间的联通快的大小了。好了,现在的话,我们可以算出来根节点的子节点的值呢?那么如果根节点不是这种颜色呢?
也就是把1节点的颜色变成黑色,那么按照我们的方法是算出来的影响联通快的大小是0的,但是实际上影响联通块的大小是1,2,3,4,7,8。所以如何解决这个问题呢?这里引入一个虚根,颜色设成0,想想现在的1节点就相当于刚才的字树的节点,就可以计算了。但是1节点的颜色可以是n种,所以对于每种可能出现的颜色都要设一个虚根。
对于DFS记录当前颜色节点上一次最早出现相同的节点显然可以可以用栈记录,但是用栈记录,我队MLE了。所以我们可以改成vector,就可以过了。
复杂度:O(n)
#include <bits/stdc++.h>
using namespace std;
const int mod = 1e9+7;
const int maxn = 200010;
const int maxm = 2*maxn;
typedef long long LL;
struct edge{
int to, next;
}E[maxm];
int n, head[maxn], edgecnt, sz[maxn], root[maxn], s[maxn], col[maxn];
bool f[maxn];
//root[i]表示i这种节点的虚结点
//s[i]表示i这个节点的对答案的贡献
//col[i]表示i这个节点的颜色
LL sum;
void init(){
edgecnt=0;
memset(head,-1,sizeof(head));
}
void add(int u, int v){
E[edgecnt].to = v, E[edgecnt].next = head[u], head[u] = edgecnt++;
}
vector <int> last[maxn];//记录当前节点出现颜色上一次出现的最近的点
void dfs1(int x, int fa){
sz[x] = 1;
for(int i = head[x]; ~i; i=E[i].next){
int v = E[i].to;
if(v == fa) continue;
dfs1(v, x);
sz[x] += sz[v];
}
}
void dfs2(int u, int fa){
int ls = last[col[u]].back();
if(ls == 0){
root[col[u]] += sz[u];
}
else{
s[ls] += sz[u];
}
last[col[u]].push_back(u);
for(int i=head[u]; ~i; i = E[i].next){
int v = E[i].to;
if(v == fa) continue;
s[u] = 0;
dfs2(v, u);
s[u] = sz[v] - s[u];
sum += 1LL*s[u]*(s[u]-1)/2;
}
last[col[u]].pop_back();
}
int main()
{
for(int i=1; i<maxn; i++) last[i].push_back(0);
int ks = 0;
while(~scanf("%d", &n)){
init();
memset(f, 0, sizeof(f));
memset(root, 0, sizeof(root));
sum = 0;
for(int i=1; i<=n; i++){
scanf("%d ", &col[i]);
f[col[i]] = 1;
}
for(int i=1; i<n; i++){
int u, v;
scanf("%d %d", &u,&v);
add(u, v);
add(v, u);
}
dfs1(1, -1);
dfs2(1, -1);
int cnt = 0;
for(int i=1; i<=n; i++){
if(f[i]){
cnt++;
int temp = sz[1] - root[i];
sum += 1LL*temp*(temp-1)/2;
}
}
LL ans = 1LL * cnt * (n-1) * n / 2;
ans -= sum;
printf("Case #%d: %lld\n", ++ks, ans);
}
return 0;
}