稍微调整了一下，感觉比原题解清晰一些。注意：外层的 sz（树形 DP 的值）和 HLD 里的 sz（子树大小）不是同一个数组，含义也不同。

#include <bits/stdc++.h>
using namespace std;
// Loop helpers: rep walks v = lo, lo+1, ..., hi; per walks v = hi, hi-1, ..., lo (both inclusive).
#define rep(v, lo, hi) for(int v = lo; v <= hi; ++v)
#define per(v, hi, lo) for(int v = hi; v >= lo; --v)

constexpr int N = 200005;   // array capacity (2e5 + 5); vertices are numbered from 1
vector<int> g[N];           // adjacency lists of the tree
int col[N];                 // current color; dfs() treats 0 as weight -1 and anything else as +1

// Tree-DP value per vertex (filled by dfs()). Not to be confused with HLD::sz,
// which is the plain subtree size.
int sz[N];
int ans[N];                 // ans[v]: how many u in v's subtree have sz[u] < 0 (filled by dfs())
int tp[N];                  // tp[i]: the vertex whose dfn is i (inverse of HLD::dfn)
int n;                      // number of vertices
namespace seg{
    int mn[N << 2], lz[N << 2];
    #define ls p << 1
    #define rs p << 1 | 1
    #define MID int mid = (l + r) >> 1
    void addtag(int p, int d){
        lz[p] += d;
        mn[p] += d;
    }
    void pushdown(int p){
        if(lz[p]){
            addtag(ls, lz[p]);
            addtag(rs, lz[p]);
            lz[p] = 0;
        }
    }
    void up(int p){
        mn[p] = min(mn[ls], mn[rs]);
    }
    void build(int l, int r, int p){
        if(l == r){
            mn[p] = sz[tp[l]];
            return;
        }
        MID;
        build(l, mid, ls); build(mid + 1, r, rs);
        up(p);
    }
    void add(int l, int r, int L, int R, int d, int p){
        if(L > r || R < l) return;
        if(L <= l && r <= R){
            addtag(p, d);
            return;
        }
        MID;
        pushdown(p);
        add(l, mid, L, R, d, ls); add(mid + 1, r, L, R, d, rs);
        up(p);
    }
    int find(int l, int r, int L, int R, int d, int p){
        if(mn[p] > d) return 0;
        if(l == r) return l;
        pushdown(p);
        MID;
        if(L == l && r == R){
            if(mn[rs] <= d) return find(mid + 1, r, mid + 1, R, d, rs);
            else return find(l, mid, l, mid, d, ls);
        }
        if(L > mid) return find(mid + 1, r, L, R, d, rs);
        else if(R <= mid) return find(l, mid, L, R, d, ls);
        else{
            int res = find(mid + 1, r, mid + 1, R, d, rs);
            if(!res) return find(l, mid, L, mid, d, ls);
            else return res;
        }
    }
}
namespace HLD{
    // Heavy-light decomposition of the tree stored in g[], plus the path updates
    // that keep the segment tree (indexed by dfn) in sync after a recolor.
    // NOTE: HLD::sz is the ordinary subtree size used to pick heavy children; it is
    // a different array from the global DP array ::sz that the segment tree holds.
    //   son[v] : heavy child of v (0 if v has no children)
    //   top[v] : first vertex of v's heavy chain
    //   dfn[v] : position of v in heavy-child-first preorder (1..cnt); ::tp is its inverse
    //   fa[v]  : parent of v (0 for the root)
    //   dep    : declared but not written by dfs1/dfs2
    int sz[N], son[N], dep[N], top[N], dfn[N], cnt, fa[N];

    // Fills fa[], sz[] and son[] for the tree rooted at x. fa[x] must already hold
    // x's parent, which is 0 for the root.
    // Iterative (preorder with an explicit stack, then sizes in reverse preorder):
    // recursing once per level can exhaust the call stack on a path-shaped tree
    // with ~2e5 vertices.
    void dfs1(int x){
        vector<int> order, stk(1, x);
        while(!stk.empty()){
            const int u = stk.back();
            stk.pop_back();
            order.push_back(u);
            for(int y : g[u]){
                if(y == fa[u]) continue;
                fa[y] = u;
                stk.push_back(y);
            }
        }
        // Every child appears after its parent in `order`, so walking it backwards
        // finishes all children before their parent. The strict '>' keeps the earliest
        // child in g[u] on ties, exactly as the recursive traversal did.
        for(int i = (int)order.size() - 1; i >= 0; --i){
            const int u = order[i];
            sz[u] = 1;
            for(int y : g[u]){
                if(y == fa[u]) continue;
                sz[u] += sz[y];
                if(sz[y] > sz[son[u]]) son[u] = y;
            }
        }
    }

    // Assigns top[] and dfn[] (and ::tp[dfn] = vertex) for the subtree of x, whose
    // chain head is f. The heavy child is numbered right after its parent, so every
    // heavy chain occupies a contiguous dfn range that increases going down. Light
    // children follow in g[] order, each starting a new chain. Iterative for the same
    // reason as dfs1; the numbering matches the recursive preorder exactly.
    void dfs2(int x, int f){
        vector<pair<int, int>> stk;   // (vertex, head of its chain)
        stk.emplace_back(x, f);
        while(!stk.empty()){
            const int u = stk.back().first, head = stk.back().second;
            stk.pop_back();
            top[u] = head;
            dfn[u] = ++cnt;
            tp[cnt] = u;
            // Push light children in reverse so they are popped in g[] order, then the
            // heavy child last so its whole subtree is popped (and numbered) first.
            for(int i = (int)g[u].size() - 1; i >= 0; --i){
                const int y = g[u][i];
                if(y == son[u] || y == fa[u]) continue;
                stk.emplace_back(y, y);
            }
            if(son[u]) stk.emplace_back(son[u], head);
        }
    }

    // Walks from x toward the root chain by chain, adding `delta` to the segment-tree
    // value of each vertex it passes. The walk stops at the deepest vertex on the path
    // whose value (before this update) is <= limit; that vertex also receives `delta`.
    // Returns that vertex's dfn. Returns 0 if no such vertex exists, in which case the
    // whole path up to and including the root was updated.
    int climb(int x, int limit, int delta){
        while(x){
            const int l = dfn[top[x]], r = dfn[x];
            const int pos = seg::find(1, n, l, r, limit, 1);
            if(pos){
                seg::add(1, n, pos, r, delta, 1);
                return pos;
            }
            seg::add(1, n, l, r, delta, 1);
            x = fa[top[x]];
        }
        return 0;
    }

    // Why the thresholds below work. The value stored for v is f(v) = ::sz[v] from
    // dfs(): f(v) = w(v) + sum over children c of max(0, f(c)), where w = -1 for
    // col 0 and +1 otherwise. Hence f(v) >= -1 for every v, and res counts the
    // vertices with f(v) < 0. A child c contributes max(0, f(c)) to its parent, so a
    // change to f(c) reaches the parent only while f(c) stays non-negative.
    // NOTE(review): assumes colors are 0/1, so a recolor always flips w(x) by 2.
    //
    // turnW: x goes from weight +1 to -1, so f(x) drops by 2; ++res for each vertex
    // that becomes negative.
    //   phase 1 (-2): vertices with f >= 2 stay non-negative and pass -2 upward.
    //     The first vertex with f <= 1 has f == 1 (x itself has f >= 1, and a parent
    //     of a vertex with f >= 2 has f >= 1). It drops to -1 (++res) and passes -1 on.
    //   phase 2 (-1): vertices with f >= 1 pass -1 upward. The first with f <= 0 has
    //     f == 0; it drops to -1 (++res) and its contribution stays 0, ending the walk.
    void turnW(int x, int &res){
        const int pos = climb(x, 1, -2);
        if(!pos) return;
        ++res;
        if(climb(fa[tp[pos]], 0, -1)) ++res;
    }

    // turnB: x goes from weight -1 to +1, so f(x) rises by 2; --res for each vertex
    // that stops being negative.
    //   phase 1 (+2): vertices with f >= 0 pass +2 upward. The first with f <= -1 has
    //     f == -1; it rises to 1 (--res) and its contribution grows only from 0 to 1.
    //   phase 2 (+1): vertices with f >= 0 pass +1 upward. The first with f == -1
    //     rises to 0 (--res) and its contribution stays 0, ending the walk.
    void turnB(int x, int &res){
        const int pos = climb(x, -1, 2);
        if(!pos) return;
        --res;
        if(climb(fa[tp[pos]], -1, 1)) --res;
    }
}
// Bottom-up tree DP over the tree rooted at x. Parent links come from HLD::fa,
// so HLD::dfs1 must have run first. With weight w(v) = -1 if col[v] == 0 else +1:
//   sz[v]  = w(v) + sum of sz[c] over children c with sz[c] > 0
//   ans[v] = ans of all children + (sz[v] < 0 ? 1 : 0),
//            i.e. the number of vertices u in v's subtree with sz[u] < 0.
// Iterative (preorder with an explicit stack, then reverse preorder) so that a
// deep, path-shaped tree cannot exhaust the call stack.
void dfs(int x){
    vector<int> order, stk(1, x);
    while(!stk.empty()){
        const int v = stk.back();
        stk.pop_back();
        order.push_back(v);
        for(int y : g[v]){
            if(y != HLD::fa[v]) stk.push_back(y);
        }
    }
    // Children appear after their parent in `order`; walking it backwards
    // finalizes every child before its parent reads it.
    for(int i = (int)order.size() - 1; i >= 0; --i){
        const int v = order[i];
        sz[v] = (col[v] == 0) ? -1 : 1;
        for(int y : g[v]){
            if(y == HLD::fa[v]) continue;
            if(sz[y] > 0) sz[v] += sz[y];   // only profitable child components are attached
            ans[v] += ans[y];
        }
        if(sz[v] < 0) ++ans[v];
    }
}
int main(){
    ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
    int q;
    cin >> n >> q;
    rep(i, 1, n) cin >> col[i];
    rep(i, 2, n){
        int u, v;
        cin >> u >> v;
        g[u].push_back(v);
        g[v].push_back(u);
    }
    HLD::dfs1(1);
    HLD::dfs2(1, 1);
    dfs(1);
    seg::build(1, n, 1);
    int res = ans[1];
    while(q--){
        int x, c;
        cin >> x >> c;
        if(col[x] == c){
            cout << res << '\n';
            continue;
        }
        col[x] = c;
        if(c == 0){
            HLD::turnW(x, res);
        }else{
            HLD::turnB(x, res);
        }
        cout << res << '\n';
    }
    return 0;
}