[AtCoder Beginner Contest 163] F - path pass i (树型dfs,容斥定律)


链接:https://atcoder.jp/contests/abc163/tasks/abc163_f

Problem Statement

We have a tree with NN vertices numbered 11 to NN. The ii-th edge in this tree connects Vertex aiai and bibi. Additionally, each vertex is painted in a color, and the color of Vertex ii is cici. Here, the color of each vertex is represented by an integer between 11 and NN (inclusive). The same integer corresponds to the same color; different integers correspond to different colors.

For each k=1,2,...,Nk=1,2,...,N, solve the following problem:

  • Find the number of simple paths that visit a vertex painted in the color kk one or more times.

Note: The simple paths from Vertex uu to vv and from vv to uu are not distinguished.

题意:

给定一个含有\(\mathit n\)个节点的树,其中第\(\mathit i\)个节点的颜色为\(c_i\)

现在问对于\(k=1, 2, ..., N\),输出有多少个简单路径有经过颜色为\(\mathit k\)的节点。

其中\(u->v\)\(v->u\)看成同一个路径。

思路:

\(ans_i\)代表有经过颜色为$\mathit i $的节点的简单路径数量。

那么我们可以反向来求,\(ans_i=\frac{n*(n+1)}{2}-b_i\),其中\(b_i\)代表路径中不经过颜色为\(\mathit i\)的节点的简单路径个数。

如何来求呢?

我们思考可以得到,路径中不经过颜色为\(\mathit i\)的节点的简单路径只存在于以下两种情况:

1️⃣:两个颜色为\(\mathit i\)的节点之间的节点相互到达的路径。

2️⃣:颜色为\(\mathit i\)的节点到[路径中不经过其他颜色为\(\mathit i\)的节点]的叶子节点之间的节点相互到达的路径。

如下图:

路径:\((7,7),(7,8),(7,6),(8,6)\dots\) 属于第一种。

路径:\((1,1),(1,3),(3,3),(4,4)\)属于第二种。

具体求法:

在树形dfs过程中,在当前节点\(\mathit x\)访问一个它的儿子节点\(\mathit y\)节点之前,记录当前

\(now=num[c_x]\)以颜色为\(c[x]\)的节点为根节点的子树大小总和,

在dfs儿子节点$\mathit y $ 之后,

\(num[c_x]-now\)代表在\(\mathit y\)节点为根节点的子树中到达节点\(\mathit x\)会经过颜色为\(c_x\)的节点个数,

\(cntson[x]\)代表以\(\mathit x\)节点为根节点的子树大小,

那么\(t=cntson[x]-(num[c_x]-now)\)就代表在\(\mathit y\)节点为根节点的子树中到达节点\(\mathit x\)之前不会经过颜色为\(c_x\)的节点个数,那么更新答案\(ans_{c[x]}-=\frac{t*(t+1)}{2}\)

在访问所有子节点之后,更新\(num[c_x]\)

在整个dfs结束后,对于颜色\(\mathit i\)\((n-num[c_i])\)也是属于第二种情况的节点,减去它们之间的路径数量即可。

代码:

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <map>
#include <set>
#include <vector>
#include <iomanip>
#include <sstream>
#include <bitset>
#include <unordered_map>
// #include <bits/stdc++.h>
#define ALL(x) (x).begin(), (x).end()
#define sz(a) int(a.size())
#define rep(i,x,n) for(int i=x;i<n;i++)
#define repd(i,x,n) for(int i=x;i<=n;i++)
#define pii pair<int,int>
#define pll pair<long long ,long long>
#define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0)
#define MS0(X) memset((X), 0, sizeof((X)))
#define MSC0(X) memset((X), '\0', sizeof((X)))
#define pb push_back
#define mp make_pair
#define fi first
#define se second
#define eps 1e-6
#define chu(x)  if(DEBUG_Switch) cout<<"["<<#x<<" "<<(x)<<"]"<<endl
#define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c))
#define du2(a,b) scanf("%d %d",&(a),&(b))
#define du1(a) scanf("%d",&(a));
using namespace std;
typedef long long ll;
ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;}
ll lcm(ll a, ll b) {return a / gcd(a, b) * b;}
ll powmod(ll a, ll b, ll MOD) { if (a == 0ll) {return 0ll;} a %= MOD; ll ans = 1; while (b) {if (b & 1) {ans = ans * a % MOD;} a = a * a % MOD; b >>= 1;} return ans;}
ll poww(ll a, ll b) { if (a == 0ll) {return 0ll;} ll ans = 1; while (b) {if (b & 1) {ans = ans * a ;} a = a * a ; b >>= 1;} return ans;}
void Pv(const vector<int> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%d", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}}
inline long long readll() {long long tmp = 0, fh = 1; char c = getchar(); while (c < '0' || c > '9') {if (c == '-') fh = -1; c = getchar();} while (c >= '0' && c <= '9') tmp = tmp * 10 + c - 48, c = getchar(); return tmp * fh;}
inline int readint() {int tmp = 0, fh = 1; char c = getchar(); while (c < '0' || c > '9') {if (c == '-') fh = -1; c = getchar();} while (c >= '0' && c <= '9') tmp = tmp * 10 + c - 48, c = getchar(); return tmp * fh;}
void pvarr_int(int *arr, int n, int strat = 1) {if (strat == 0) {n--;} repd(i, strat, n) {printf("%d%c", arr[i], i == n ? '\n' : ' ');}}
void pvarr_LL(ll *arr, int n, int strat = 1) {if (strat == 0) {n--;} repd(i, strat, n) {printf("%lld%c", arr[i], i == n ? '\n' : ' ');}}
const int maxn = 1000010;
const int inf = 0x3f3f3f3f;
/*** TEMPLATE CODE * * STARTS HERE ***/
#define DEBUG_Switch 0
int n;
std::vector<int> v[maxn];
int cnt_son[maxn];
ll ans[maxn];
ll calculate(ll x)
{
    return x * (x + 1) / 2;
}
int col[maxn];
void build()
{
    n = readint();
    int x, y;
    repd(i, 1, n)
    {
        col[i] = readint();
    }
    repd(i, 1, n - 1)
    {
        x = readint();
        y = readint();
        v[x].pb(y);
        v[y].pb(x);
    }
}
int num[maxn];
void dfs(int x, int pre)
{
    cnt_son[x] = 1;
    int c = col[x];
    int temp = num[c];
    for (auto &y : v[x])
    {
        if (y != pre)
        {
            int now = num[c];
            dfs(y, x);
            now = num[c] - now;
            ans[c] -= calculate(cnt_son[y] - now);
            cnt_son[x] += cnt_son[y];
        }
    }
    num[c] = temp + cnt_son[x];
}
int main()
{
#if DEBUG_Switch
    freopen("C:\\code\\input.txt", "r", stdin);
#endif
    //freopen("C:\\code\\output.txt","r",stdin);
    build();
    repd(i, 1, n)
    {
        ans[i] = calculate(n);
    }
    dfs(1, 0);
    repd(i, 1, n)
    {
        ans[i] -= calculate(n - num[i]);
        printf("%lld\n", ans[i] );
    }
    return 0;
}