https://ac.nowcoder.com/acm/contest/317/I

C++版本一

std

题解：启发式合并（small-to-large merging of colour classes）

#include<algorithm>
#include<cctype>
#include<cmath>
#include<cstdio>
#include<cstring>
#include<ctime>
#include<iostream>
#include<map>
#include<queue>
#include<set>
#include<utility>
#include<vector>
#define ll long long
#define M 200010
using namespace std;
// Reads the next decimal integer from stdin.
// Every character before the first digit is skipped; a '-' among those
// skipped characters makes the result negative. If end of input is reached
// before any digit, returns 0 instead of looping forever on EOF.
int read() {
    int nm = 0, f = 1;
    // int, not char: getchar() returns an unsigned-char value or EOF, and
    // isdigit() is only defined for exactly those values.
    int c = getchar();
    while (c != EOF && !isdigit(c)) {
        if (c == '-') f = -1;
        c = getchar();
    }
    while (c != EOF && isdigit(c)) {
        nm = nm * 10 + (c - '0');
        c = getchar();
    }
    return nm * f;
}

// Tree and colour-class state shared by the functions below (zero-initialised).
int note[M];              // not read anywhere in the visible code
int sz[M];                // sz[c]: number of vertices currently in class label c
int cor[M];               // cor[v]: class label of vertex v (cor[0] is a sentinel set in main)
int id[M];                // id[c]: class label currently holding colour c (0 = none)
std::vector<int> to[M];   // to[v]: neighbours of vertex v
std::vector<int> to1[M];  // to1[c]: vertices whose class label is c
int n, q;                 // vertex count, query count
int ans;                  // count of differently-coloured edges, maintained by dfs/del/insert
// Walks the subtree of `now` (entered from `fa`). For every vertex whose
// colour is nonzero and differs from its parent's colour, increments `ans`.
// This counts the differently-coloured edges, including the edge now–fa.
// Uses an explicit stack instead of recursion, so a path-shaped tree with
// ~2e5 vertices cannot overflow the call stack. Visit order differs from the
// recursive version, but only `ans` is updated, so the result is the same.
void dfs(int now, int fa) {
    std::vector<std::pair<int, int>> pending;  // (vertex, parent) still to visit
    pending.emplace_back(now, fa);
    while (!pending.empty()) {
        const int v = pending.back().first;
        const int p = pending.back().second;
        pending.pop_back();
        if (cor[v] != 0 && cor[v] != cor[p]) ans++;
        for (const int nb : to[v]) {
            if (nb != p) pending.emplace_back(nb, v);
        }
    }
}

// Subtracts from `ans` one for every neighbour of x whose current colour
// differs from x's. In main this runs before x is recoloured, removing x's
// boundary edges under the old colour.
void del(int x) {
    const int own = cor[x];
    for (const int nb : to[x]) {
        if (cor[nb] != own) --ans;
    }
}
// Adds to `ans` one for every neighbour of x whose current colour differs
// from x's. In main this runs after x is recoloured, re-adding x's boundary
// edges under the new colour.
void insert(int x) {
    const int own = cor[x];
    for (const int nb : to[x]) {
        if (cor[nb] != own) ++ans;
    }
}
// Diagnostic counters; neither is read or printed anywhere in the visible code.
int tot = 0;
int tot2 = 0;
int main() {
    n = read(), q = read();
    for(int i = 1; i <= n; i++) cor[i] = read(), sz[cor[i]]++, to1[cor[i]].push_back(i), id[i] = i, note[i] = i;
    for(int i = 1; i < n; i++) {
        int vi = read(), vj = read();
        to[vi].push_back(vj), to[vj].push_back(vi);
    }
    to[1].push_back(0), cor[0] = 0x3e3e3e3e;
    dfs(1, 0);
    while(q--) {
        int x = read(), y = read();
        int xn = id[x], yn = id[y];
        if(sz[xn] < sz[yn]) {
            tot += sz[xn], tot2 += to1[xn].size();
            for(int i = 0; i < to1[xn].size(); i++) {
                int op = to1[xn][i];
                del(op);
                to1[yn].push_back(op);
            }
            for(int i = 0; i < to1[xn].size(); i++) {
                int op = to1[xn][i];
                cor[op] = yn;
            }
            for(int i = 0; i < to1[xn].size(); i++) {
                int op = to1[xn][i];
                insert(op);
            }
            to1[xn].clear();
            sz[yn] += sz[xn];
            sz[xn] = 0;
            id[x] = 0;
        } else {
            tot+=sz[yn], tot2 += to1[yn].size();
            for(int i = 0; i < to1[yn].size(); i++) {
                int op = to1[yn][i];
                del(op);
                to1[xn].push_back(op);
            }
            for(int i = 0; i < to1[yn].size(); i++) {
                int op = to1[yn][i];
                cor[op] = xn;
            }
            for(int i = 0; i < to1[yn].size(); i++) {
                int op = to1[yn][i];
                insert(op);
            }
            to1[yn].clear();
            sz[xn] += sz[yn];
            sz[yn] = 0;
            id[y] = xn;
            id[x] = 0;
        }
        cout << ans << "\n";
    }
    return 0;
}