New Year Tree

首先注意到颜色的范围是 $[1, 60]$，容易想到用一个 64 位二进制数来表示颜色集合：二进制的第 $i$ 位表示是否含有第 $i$ 种颜色。

接下来只需求出整棵树的 DFS 序（每棵子树在 DFS 序上对应一段连续区间 $[l_v, r_v]$），然后在其上维护一棵按位或（OR）线段树，支持区间赋值与区间查询即可。
这是一道比较套路的模板题。

#include <bits/stdc++.h>
// Segment-tree helper macros. They are expanded textually, so they refer to
// whatever `rt` (node index), `l` and `r` (node interval) are in scope at the
// use site. `l + r >> 1` parses as `(l + r) >> 1`.
// `lson` / `rson` expand to the argument triple (child index, child interval).
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
#define ls rt << 1
#define rs rt << 1 | 1

using namespace std;

typedef long long ll;

// Array bound for vertex ids and DFS positions.
const int N = 4e5 + 10;

// Adjacency lists in forward-star form: head[x] is the first edge leaving x,
// to[e] / nex[e] are the target and the next edge of edge e. Edge ids start
// at 1 so that 0 terminates a list; main() stores each tree edge twice.
int head[N], to[N << 1], nex[N << 1], cnt = 1;

// value[v]: color of vertex v. l[v] / r[v]: first and last DFS positions of
// v's subtree. rk[p]: vertex at DFS position p. tot: positions assigned so far.
int n, m, value[N], rk[N], l[N], r[N], tot;

// tree[rt]: bitmask of colors present in node rt's interval.
// lazy[rt]: pending "assign this mask" tag for rt's children; 0 means none.
ll tree[N << 2], lazy[N << 2];

// Prepends the directed edge x -> y to x's adjacency list.
void add(int x, int y) {
  const int edge = cnt++;
  to[edge] = y;
  nex[edge] = head[x];
  head[x] = edge;
}

void push_up(int rt) {
  tree[rt] = tree[ls] | tree[rs];
}

void push_down(int rt, int l, int r) {
  if (lazy[rt]) {
    tree[ls] = tree[rs] = lazy[ls] = lazy[rs] = lazy[rt];
    lazy[rt] = 0;
  }
}

// Builds node rt over DFS positions [l, r]. Each leaf holds the single color
// bit of the vertex at that DFS position; inner nodes OR their children.
void build(int rt, int l, int r) {
  if (l != r) {
    const int half = (l + r) >> 1;
    build(rt << 1, l, half);
    build(rt << 1 | 1, half + 1, r);
    push_up(rt);
    return;
  }
  tree[rt] = 1ll << value[rk[l]];
}

// Assigns the color mask `value` to every position in [L, R].
// Node rt covers [l, r]. A fully covered node is overwritten and tagged, so
// its children are only updated later by push_down.
void update(int rt, int l, int r, int L, int R, ll value) {
  if (L <= l && r <= R) {
    tree[rt] = lazy[rt] = value;
    return;
  }
  push_down(rt, l, r);
  const int half = (l + r) >> 1;
  if (L <= half) update(rt << 1, l, half, L, R, value);
  if (half < R) update(rt << 1 | 1, half + 1, r, L, R, value);
  push_up(rt);
}

// Returns the OR of the color masks over positions [L, R]; node rt covers [l, r].
ll query(int rt, int l, int r, int L, int R) {
  if (L <= l && r <= R) {
    return tree[rt];
  }
  push_down(rt, l, r);
  const int half = (l + r) >> 1;
  const ll from_left = (L <= half) ? query(rt << 1, l, half, L, R) : 0;
  const ll from_right = (half < R) ? query(rt << 1 | 1, half + 1, r, L, R) : 0;
  return from_left | from_right;
}

void dfs(int rt, int fa) {
  l[rt] = ++tot, rk[tot] = rt;
  for (int i = head[rt]; i; i = nex[i]) {
    if (to[i] == fa) {
      continue;
    }
    dfs(to[i], rt);
  }
  r[rt] = tot;
}

int main() {
  // freopen("in.txt", "r", stdin);
  // freopen("out.txt", "w", stdout);
  // ios::sync_with_stdio(false), cin.tie(0), cout.tie(0);
  scanf("%d %d", &n, &m);
  for (int i = 1; i <= n; i++) {
    scanf("%d", &value[i]);
  }
  for (int i = 1; i < n; i++) {
    int x, y;
    scanf("%d %d", &x, &y);
    add(x, y);
    add(y, x);
  }
  dfs(1, 0);
  build(1, 1, n);
  for (int i = 1; i <= m; i++) {
    int op, v, c;
    scanf("%d %d", &op, &v);
    if (op & 1) {
      scanf("%d", &c);
      update(1, 1, n, l[v], r[v], 1ll << c);
    }
    else {
      ll ans = query(1, 1, n, l[v], r[v]);
      int res = 0;
      for (int i = 1; i <= 60; i++) {
        res += ans >> i & 1;
      }
      printf("%d\n", res);
    }
  }
  return 0;
}