题目描述

给你的一棵树,每个节点有自己的颜色。下面有两种操作。
操作一:把一个节点及其子树全部改成颜色
操作二:输出一个节点及它子树中一共有几种颜色。题目涉及的颜色种类

Solution

观察颜色种类就是这个题目的切入点,看到最大不超过种的颜色,就在指引我们去状压这个颜色了。
那么维护数上子节点,这个要联想到树链剖分的思路,把一棵树变成一条链,对这个链操作的话,就会有很多新的方法,比如说线段树。
这个题目就是线段树+状压统计即可,使用标记去记录赋值标记。每次节点都是赋值,到了这个时候就可以看出是和区间赋值的板子一样了。最后上传标记使用或运算在统计有几个二进制一即可出答案。注意问题就变成了,拿到序,把树变成链,在进行区间赋值操作,问区间有几个不同的数题目。

#include <bits/stdc++.h>
using namespace std;
#define js ios::sync_with_stdio(false);cin.tie(0); cout.tie(0)
#define all(__vv__) (__vv__).begin(), (__vv__).end()
#define endl "\n"
#define pai pair<int, int>
#define ms(__x__,__val__) memset(__x__, __val__, sizeof(__x__))
#define rep(i, sta, en) for(int i=sta; i<=en; ++i)
typedef long long ll; typedef unsigned long long ull; typedef long double ld;
inline ll read() { ll s = 0, w = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') w = -1; for (; isdigit(ch); ch = getchar())    s = (s << 1) + (s << 3) + (ch ^ 48); return s * w; }
inline void print(ll x, int op = 10) { if (!x) { putchar('0'); if (op)    putchar(op); return; }    char F[40]; ll tmp = x > 0 ? x : -x;    if (x < 0)putchar('-');    int cnt = 0;    while (tmp > 0) { F[cnt++] = tmp % 10 + '0';        tmp /= 10; }    while (cnt > 0)putchar(F[--cnt]);    if (op)    putchar(op); }
inline ll gcd(ll x, ll y) { return y ? gcd(y, x % y) : x; }
ll qpow(ll a, ll b) { ll ans = 1;    while (b) { if (b & 1)    ans *= a;        b >>= 1;        a *= a; }    return ans; }    ll qpow(ll a, ll b, ll mod) { ll ans = 1; while (b) { if (b & 1)(ans *= a) %= mod; b >>= 1; (a *= a) %= mod; }return ans % mod; }
const int dir[][2] = { {0,1},{1,0},{0,-1},{-1,0},{1,1},{1,-1},{-1,1},{-1,-1} };
const int MOD = 1e9 + 7;
const int INF = 0x3f3f3f3f;
const int N = 4e5 + 7;

/*前向星部分*/
int head[N], tot1;
struct Node {
    int v, next;
}edge[N << 1];
void add(int u, int v) {
    edge[++tot1].v = v;
    edge[tot1].next = head[u];
    head[u] = tot1;
}

/*dfs序部分,化树为链*/
int sz[N], dfn[N], dfn_id[N], tot2;
void dfs(int u, int fa) {
    sz[u] = 1;
    dfn[u] = ++tot2, dfn_id[tot2] = u;
    for (int i = head[u]; i; i = edge[i].next) {
        int v = edge[i].v;
        if (v == fa)    continue;
        dfs(v, u);
        sz[u] += sz[v];
    }
}

/*线段树部分、状压*/
#define mid (l + r >> 1)
#define lson rt << 1, l, mid
#define rson rt << 1 | 1, mid + 1, r
ll a[N], sum[N << 2], lazy[N << 2];

void push_down(int rt, int l, int r) {
    if (lazy[rt]) {
        sum[rt << 1] = lazy[rt << 1] = lazy[rt];
        sum[rt << 1 | 1] = lazy[rt << 1 | 1] = lazy[rt];
        lazy[rt] = 0;
    }
}
void push_up(int rt, int l, int r) {
    sum[rt] = sum[rt << 1] | sum[rt << 1 | 1];
}
void build(int rt, int l, int r) {
    if (l == r) {
        sum[rt] = 1ll << a[dfn_id[l]];
        lazy[rt] = 0;
        return;
    }
    build(lson);
    build(rson);
    push_up(rt, l, r);
}
void update(int rt, int l, int r, int L, int R, ll z) {
    if (L <= l and r <= R) {
        sum[rt] = lazy[rt] = z;
        return;
    }
    push_down(rt, l, r);
    if (L <= mid)    update(lson, L, R, z);
    if (R > mid)    update(rson, L, R, z);
    push_up(rt, l, r);
}
ll query(int rt, int l, int r, int L, int R) {
    if (L <= l and r <= R)    return sum[rt];
    push_down(rt, l, r);
    ll res = 0;
    if (L <= mid)    res |= query(lson, L, R);
    if (R > mid)    res |= query(rson, L, R);
    return res;
}

int calc(ll x) {
    int res = 0;
    for (int i = 1; i <= 60; ++i)
        if (x & (1ll << i))    ++res;
    return res;
}

void solve() {
    int n = read(), m = read();
    for (int i = 1; i <= n; ++i)    a[i] = read();
    for (int i = 1; i < n; ++i) {
        int u = read(), v = read();
        add(u, v), add(v, u);
    }
    dfs(1, -1);
    build(1, 1, n);
    while (m--) {
        int op = read();
        if (op & 1) {
            int u = read(), z = read();
            update(1, 1, n, dfn[u], dfn[u] + sz[u] - 1, 1ll << z);
        }
        else {
            int u = read();
            ll ans = query(1, 1, n, dfn[u], dfn[u] + sz[u] - 1);
            int cnt = calc(ans);
            print(cnt);
        }
    }
}

int main() {
    //int T = read();    while (T--)
    solve();
    return 0;
}