CF620E New Year Tree
题目地址:https://codeforces.com/contest/620/problem/E
基本思路:
我们发现颜色最多只有60种,所以明显我们可以用一个64位整数的二进制位来表示一个颜色集合(第c位为1表示出现了颜色c),
并且因为操作中只有对子树的修改和查询,而一棵子树在dfs序上恰好对应一段连续区间,所以我们直接用dfs序+线段树(区间按位或合并、区间赋值打懒标记)维护一下二进制状态就是了,查询的答案就是结果中1的个数。
参考代码:
// CF620E "New Year Tree" -- reference solution.
//
// Operations on a tree rooted at vertex 1:
//   1 u x : recolor every vertex in the subtree of u with color x;
//   2 u   : print the number of distinct colors in the subtree of u.
//
// Approach: a DFS (Euler-tour) order maps each subtree to the contiguous
// position range [tin[u], tout[u]]. A color set is stored as a 64-bit mask
// (bit c set <=> color c present), so a segment tree with OR-merge and lazy
// range assignment answers both operations in O(log n); the answer is the
// number of set bits in the queried mask.
//
// Assumes colors lie in [1, 60] (the problem constraints), so 1 << c fits in
// 64 bits and every assigned mask is nonzero (0 is reserved for "no lazy tag").
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

namespace {

using Mask = std::uint64_t;

// Reads one (optionally negative) integer from stdin, skipping any non-digit
// characters before it. Returns 0 if input ends before a digit is seen,
// instead of looping forever on EOF.
long long read_int() {
    int ch = std::getchar();
    long long sign = 1;
    while (ch != EOF && !std::isdigit(ch)) {
        if (ch == '-') sign = -1;
        ch = std::getchar();
    }
    long long value = 0;
    while (ch != EOF && std::isdigit(ch)) {
        value = value * 10 + (ch - '0');
        ch = std::getchar();
    }
    return sign * value;
}

// Number of set bits in `m`, i.e. the number of distinct colors it encodes.
int count_bits(Mask m) {
    int count = 0;
    for (; m != 0; m &= m - 1) ++count;  // clears the lowest set bit each step
    return count;
}

// Iterative DFS from `root` over `adj` (vertex ids 1..n; 0 means "no parent").
// Fills tin[u] / tout[u] with the first / last DFS positions of u's subtree and
// order[pos] with the vertex visited at position pos (positions start at 1).
// Neighbors are visited in adjacency-list order, matching a recursive DFS, but
// without recursion so a path-shaped tree cannot overflow the call stack.
void euler_tour(int root, const std::vector<std::vector<int>>& adj,
                std::vector<int>& tin, std::vector<int>& tout,
                std::vector<int>& order) {
    std::vector<int> parent(adj.size(), 0);
    std::vector<std::size_t> next_edge(adj.size(), 0);
    std::vector<int> stack;
    stack.reserve(adj.size());

    int timer = 0;
    tin[root] = ++timer;
    order[timer] = root;
    stack.push_back(root);

    while (!stack.empty()) {
        const int u = stack.back();
        if (next_edge[u] < adj[u].size()) {
            const int v = adj[u][next_edge[u]++];
            if (v == parent[u]) continue;
            parent[v] = u;
            tin[v] = ++timer;
            order[timer] = v;
            stack.push_back(v);
        } else {
            tout[u] = timer;  // all descendants of u have been numbered
            stack.pop_back();
        }
    }
}

// Segment tree over positions [1, n] storing OR-ed color masks, with lazy
// "assign this mask to the whole range" tags (tag 0 = none pending).
class ColorSegmentTree {
public:
    // leaf_mask[pos] for pos in [1, n] is the initial mask at position pos;
    // index 0 is unused.
    explicit ColorSegmentTree(const std::vector<Mask>& leaf_mask)
        : n_(static_cast<int>(leaf_mask.size()) - 1),
          mask_(4 * leaf_mask.size() + 4, 0),
          lazy_(4 * leaf_mask.size() + 4, 0) {
        if (n_ >= 1) build(1, 1, n_, leaf_mask);
    }

    // Sets every position in [l, r] to `value` (a nonzero single-color mask).
    void assign(int l, int r, Mask value) {
        if (n_ >= 1) assign_impl(1, 1, n_, l, r, value);
    }

    // OR of all masks in [l, r]; 0 for an empty tree.
    Mask query(int l, int r) {
        return n_ >= 1 ? query_impl(1, 1, n_, l, r) : 0;
    }

private:
    static int left(int node) { return node << 1; }
    static int right(int node) { return node << 1 | 1; }

    void pull(int node) { mask_[node] = mask_[left(node)] | mask_[right(node)]; }

    void apply(int node, Mask value) {
        mask_[node] = value;
        lazy_[node] = value;
    }

    void push(int node) {
        if (lazy_[node] != 0) {
            apply(left(node), lazy_[node]);
            apply(right(node), lazy_[node]);
            lazy_[node] = 0;
        }
    }

    void build(int node, int lo, int hi, const std::vector<Mask>& leaf_mask) {
        if (lo == hi) {
            mask_[node] = leaf_mask[lo];
            return;
        }
        const int mid = lo + (hi - lo) / 2;
        build(left(node), lo, mid, leaf_mask);
        build(right(node), mid + 1, hi, leaf_mask);
        pull(node);
    }

    void assign_impl(int node, int lo, int hi, int l, int r, Mask value) {
        if (r < lo || hi < l) return;
        if (l <= lo && hi <= r) {
            apply(node, value);
            return;
        }
        push(node);
        const int mid = lo + (hi - lo) / 2;
        assign_impl(left(node), lo, mid, l, r, value);
        assign_impl(right(node), mid + 1, hi, l, r, value);
        pull(node);
    }

    Mask query_impl(int node, int lo, int hi, int l, int r) {
        if (r < lo || hi < l) return 0;
        if (l <= lo && hi <= r) return mask_[node];
        push(node);
        const int mid = lo + (hi - lo) / 2;
        return query_impl(left(node), lo, mid, l, r) |
               query_impl(right(node), mid + 1, hi, l, r);
    }

    int n_;
    std::vector<Mask> mask_;
    std::vector<Mask> lazy_;
};

}  // namespace

int main() {
    const int n = static_cast<int>(read_int());
    long long queries = read_int();
    if (n <= 0) return 0;  // nothing to build; avoids an unbounded build(1, 1, 0)

    std::vector<long long> color(n + 1, 0);
    for (int v = 1; v <= n; ++v) color[v] = read_int();

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; ++i) {
        const int u = static_cast<int>(read_int());
        const int v = static_cast<int>(read_int());
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    std::vector<int> tin(n + 1, 0);
    std::vector<int> tout(n + 1, 0);
    std::vector<int> order(n + 1, 0);
    euler_tour(1, adj, tin, tout, order);

    std::vector<Mask> leaf_mask(n + 1, 0);
    for (int pos = 1; pos <= n; ++pos) {
        leaf_mask[pos] = Mask{1} << color[order[pos]];
    }
    ColorSegmentTree tree(leaf_mask);

    while (queries-- > 0) {
        const long long op = read_int();
        if (op == 1) {
            const int u = static_cast<int>(read_int());
            const long long x = read_int();
            tree.assign(tin[u], tout[u], Mask{1} << x);
        } else {
            const int u = static_cast<int>(read_int());
            std::printf("%d\n", count_bits(tree.query(tin[u], tout[u])));
        }
    }
    return 0;
}