CF620E New Year Tree

题目地址:

https://ac.nowcoder.com/acm/problem/111259

基本思路:

我们发现颜色最多只有 60 种,所以可以用一个 64 位整数的二进制位来表示颜色集合:第 k 位为 1 表示颜色 k 出现,

并且因为操作只涉及对子树的修改和查询,而一棵子树在 dfs 序中恰好对应一段连续区间,所以直接用 dfs 序 + 线段树(区间赋值、区间按位或)维护二进制状态即可,查询结果中 1 的个数就是答案。

参考代码:

// GCC optimisation-level hints.
#pragma GCC optimize(2)
#pragma GCC optimize(3)
#include <bits/stdc++.h>
using namespace std;
// Shorthand macros; only some of them (e.g. rep, pb) are used further down.
#define IO std::ios::sync_with_stdio(false); cin.tie(0)
#define ll long long
#define ull unsigned long long
#define SZ(x) ((int)(x).size())
#define all(x) (x).begin(), (x).end()
#define rep(i, l, r) for (int i = l; i <= r; i++)
#define per(i, l, r) for (int i = l; i >= r; i--)
#define mset(s, _) memset(s, _, sizeof(s))
#define pb push_back
#define pii pair <int, int>
#define mp(a, b) make_pair(a, b)
#define debug(x) cerr << #x << " = " << x << '\n';
#define pll pair <ll, ll>
#define fir first
#define sec second
#define INF 0x3f3f3f3f
// NOTE: from here on every `int` is really `long long`. This presumably exists
// so the colour bitmasks built with `1ll << colour` (main inspects bits up to
// 60) can live in `int`-declared variables. Since `int main` would expand to
// `long long main`, the entry point below is written as `signed main`.
#define int ll

// Reads one (optionally negative) decimal integer from stdin.
// Any non-digit characters before the number are skipped; a '-' among them
// makes the result negative. If end-of-input is reached before a digit, the
// function returns 0 instead of spinning forever.
inline int read() {
  int x = 0, neg = 1;
  // Keep getchar()'s int result: EOF must stay distinguishable from data, and
  // isdigit() is only defined for EOF or values representable as unsigned char
  // (a plain `char` holding a byte >= 0x80 would be negative -> UB).
  int op = getchar();
  while (op != EOF && !isdigit(op)) {
    if (op == '-') neg = -1;
    op = getchar();
  }
  while (op != EOF && isdigit(op)) {
    x = 10 * x + op - '0';
    op = getchar();
  }
  return neg * x;
}
// Writes the decimal form of x to stdout, without a trailing newline.
// A negative value is prefixed with '-'. Digits are collected least-significant
// first into a small buffer and then emitted in reverse order.
inline void print(int x) {
  if (x < 0) {
    putchar('-');
    x = -x;
  }
  char buf[24];
  int len = 0;
  do {
    buf[len++] = static_cast<char>('0' + x % 10);
    x /= 10;
  } while (x != 0);
  while (len > 0) putchar(buf[--len]);
}

const int maxn = 4e5 + 10;  // capacity of every vertex-/position-indexed array
vector<int> G[maxn];        // adjacency lists; each tree edge is stored both ways
int n;                      // number of vertices
int c[maxn];                // c[v]: initial colour of vertex v
int q;                      // number of queries
int L[maxn];                // L[v]: DFS position at which v is entered
int R[maxn];                // R[v]: last DFS position inside v's subtree
int tot;                    // running DFS position counter
int dfn[maxn];              // dfn[t]: vertex entered at DFS position t
// Assigns DFS (pre-order) positions to the subtree rooted at u:
//   L[v]   : position at which v is first visited (assigned via ++tot)
//   R[v]   : largest position inside v's subtree, so v's subtree occupies the
//            contiguous range [L[v], R[v]]
//   dfn[t] : the vertex visited at position t (inverse of L)
// `par` is the vertex u was reached from and is skipped among u's neighbours.
// Uses an explicit stack instead of recursion: on a path-shaped tree with up to
// ~4e5 vertices, recursion would nest that deep and can overflow a default-sized
// call stack. Children are still visited in adjacency-list order, so the
// resulting L / R / dfn are exactly those of the recursive formulation.
void dfs(int u, int par) {
  struct Frame {
    int node;          // vertex currently being expanded
    int parent;        // vertex it was entered from (not a child)
    std::size_t next;  // index in G[node] of the next neighbour to examine
  };
  std::vector<Frame> frames;
  L[u] = ++tot;
  dfn[tot] = u;
  frames.push_back({u, par, 0});
  while (!frames.empty()) {
    Frame& top = frames.back();
    if (top.next == G[top.node].size()) {
      // Every child is finished, so the subtree ends at the current position.
      R[top.node] = tot;
      frames.pop_back();
      continue;
    }
    const int cur = top.node;
    const int to = G[cur][top.next++];
    if (to == top.parent) continue;
    L[to] = ++tot;
    dfn[tot] = to;
    // push_back may reallocate and invalidate `top`; it is re-fetched next loop.
    frames.push_back({to, cur, 0});
  }
}
#define ls (index << 1)
#define rs (index << 1 | 1)
// Segment tree over DFS positions. Every node stores a bitmask of the colours
// occurring in its interval (bit k set <=> some position there has colour k).
// Supports "assign one colour to a range" and "OR of masks over a range".
// The children of node `index` are 2*index and 2*index+1.
struct SegmentTree {

    struct Node {
        // l, r : interval of positions covered by this node
        // res  : OR of the colour bits over [l, r]
        // lazy : mask assigned to the whole interval but not yet pushed to the
        //        children; 0 means "nothing pending"
        int l, r, res, lazy;
    } tr[maxn * 4];

    // Recomputes a node's mask from its two children.
    inline void push_up(int index) {
      const int lc = index << 1;
      tr[index].res = tr[lc].res | tr[lc | 1].res;
    }
    // Forwards a pending assignment to both children, then clears it here.
    inline void push_down(int index) {
      const int tag = tr[index].lazy;
      if (tag == 0) return;
      const int lc = index << 1;
      tr[lc].res = tag;
      tr[lc].lazy = tag;
      tr[lc | 1].res = tag;
      tr[lc | 1].lazy = tag;
      tr[index].lazy = 0;
    }
    // Builds node `index` over positions [l, r]; leaf at position p holds the
    // single bit of the colour of vertex dfn[p].
    void build(int index, int l, int r) {
      Node& cur = tr[index];
      cur.l = l;
      cur.r = r;
      if (l != r) {
        const int mid = (l + r) >> 1;
        build(index << 1, l, mid);
        build(index << 1 | 1, mid + 1, r);
        push_up(index);
      } else {
        cur.res = 1ll << c[dfn[l]];
      }
    }
    // Assigns colour `val` to every position in [l, r].
    void change(int index, int l, int r, int val) {
      if (l <= tr[index].l && tr[index].r <= r) {
        const int mask = 1ll << val;
        tr[index].res = mask;
        tr[index].lazy = mask;
        return;
      }
      push_down(index);
      const int mid = (tr[index].l + tr[index].r) >> 1;
      if (l <= mid) change(index << 1, l, r, val);
      if (r > mid) change(index << 1 | 1, l, r, val);
      push_up(index);
    }
    // Returns the OR of the colour masks over positions [l, r].
    int query(int index, int l, int r) {
      if (l <= tr[index].l && tr[index].r <= r) return tr[index].res;
      push_down(index);
      const int mid = (tr[index].l + tr[index].r) >> 1;
      int acc = 0;
      if (l <= mid) acc |= query(index << 1, l, r);
      if (r > mid) acc |= query(index << 1 | 1, l, r);
      return acc;
    }
}smt;

signed main() {
  n = read(),q = read();
  rep(i,1,n) c[i] = read();
  rep(i,1,n - 1){
    int u = read(),v = read();
    G[u].pb(v);
    G[v].pb(u);
  }
  tot = 0;
  dfs(1,0);
  smt.build(1,1,tot);
  while (q--){
    int op = read();
    if(op == 1){
      int u = read(),x = read();
      smt.change(1,L[u],R[u],x);
    }else{
      int u = read();
      int now = smt.query(1,L[u],R[u]);
      int ans = 0;
      for(int i = 1 ; i <= 60 ; i++) if((now >> i) & 1) ans++;
      print(ans); puts("");
    }
  }
  return 0;
}