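// Slightly reworked from the original editorial solution; it should read a bit more clearly.
// Note: the global (outer) sz[] is a different array from the sz[] inside namespace HLD.
// The global one is the signed per-vertex value used by dfs() and the segment tree;
// HLD::sz[] is the ordinary subtree size used for the heavy-light decomposition.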
#include <bits/stdc++.h>
using namespace std;
// Inclusive ascending / descending loop helpers. Macro arguments are
// parenthesised so that an expression argument such as `k ? a : b` binds as a
// whole instead of being split apart by operator precedence.
#define rep(i, a, b) for(int i = (a); i <= (b); ++i)
#define per(i, a, b) for(int i = (a); i >= (b); --i)
// Array capacity; vertex ids are assumed to be < N (TODO confirm against the problem limits).
const int N = 200000 + 5;
std::vector<int> g[N];   // undirected adjacency lists (main() stores each edge at both endpoints)
int col[N];              // current colour of each vertex, as read from input / queries
// sz[v]  : (col[v] == 0 ? -1 : +1) plus the sz of every child whose sz is positive;
//          filled by dfs() and used as the segment tree's initial leaf values.
//          NOT the same array as HLD::sz (subtree size).
// ans[v] : number of vertices u in v's subtree with sz[u] < 0 (after dfs()).
// tp[i]  : the vertex whose HLD dfn equals i.
int sz[N], ans[N], tp[N];
int n;                   // number of vertices
namespace seg{
int mn[N << 2], lz[N << 2];
#define ls p << 1
#define rs p << 1 | 1
#define MID int mid = (l + r) >> 1
void addtag(int p, int d){
lz[p] += d;
mn[p] += d;
}
void pushdown(int p){
if(lz[p]){
addtag(ls, lz[p]);
addtag(rs, lz[p]);
lz[p] = 0;
}
}
void up(int p){
mn[p] = min(mn[ls], mn[rs]);
}
void build(int l, int r, int p){
if(l == r){
mn[p] = sz[tp[l]];
return;
}
MID;
build(l, mid, ls); build(mid + 1, r, rs);
up(p);
}
void add(int l, int r, int L, int R, int d, int p){
if(L > r || R < l) return;
if(L <= l && r <= R){
addtag(p, d);
return;
}
MID;
pushdown(p);
add(l, mid, L, R, d, ls); add(mid + 1, r, L, R, d, rs);
up(p);
}
int find(int l, int r, int L, int R, int d, int p){
if(mn[p] > d) return 0;
if(l == r) return l;
pushdown(p);
MID;
if(L == l && r == R){
if(mn[rs] <= d) return find(mid + 1, r, mid + 1, R, d, rs);
else return find(l, mid, l, mid, d, ls);
}
if(L > mid) return find(mid + 1, r, L, R, d, rs);
else if(R <= mid) return find(l, mid, L, R, d, ls);
else{
int res = find(mid + 1, r, mid + 1, R, d, rs);
if(!res) return find(l, mid, L, mid, d, ls);
else return res;
}
}
}
// Heavy-light decomposition of the input tree plus the recolouring updates.
// NOTE: HLD::sz[] is the ordinary subtree size; it is a different array from
// the global ::sz[] (signed per-vertex value) whose values the segment tree holds.
namespace HLD{
// sz: subtree size; son: heavy child (0 = none); dep: currently unused;
// top: head vertex of v's heavy chain; dfn: position of v in heavy-child-first
// preorder; cnt: last dfn handed out; fa: parent (0 for the root).
int sz[N], son[N], dep[N], top[N], dfn[N], cnt, fa[N];

// Fills fa[], sz[] and son[] for the tree rooted at x (fa[x] is expected to be
// 0 on entry). Uses an explicit stack rather than recursion so a path-shaped
// tree with ~2e5 vertices cannot overflow the call stack.
void dfs1(int x){
    std::vector<int> order;           // preorder: every parent precedes its children
    std::vector<int> stk(1, x);
    while(!stk.empty()){
        const int u = stk.back();
        stk.pop_back();
        order.push_back(u);
        for(int y : g[u]){
            if(y == fa[u]) continue;
            fa[y] = u;
            stk.push_back(y);
        }
    }
    // Reverse preorder reaches children before parents, so child sizes are
    // final. Scanning g[u] in stored order with a strict '>' picks the same
    // heavy child the recursive version picked.
    for(auto it = order.rbegin(); it != order.rend(); ++it){
        const int u = *it;
        sz[u] = 1;
        for(int y : g[u]){
            if(y == fa[u]) continue;
            sz[u] += sz[y];
            if(sz[y] > sz[son[u]]) son[u] = y;
        }
    }
}

// Assigns top[], dfn[] and tp[] (dfn -> vertex) in heavy-child-first preorder
// starting at x, whose chain head is f. Iterative for the same reason as dfs1;
// the push order below reproduces the recursive visiting order exactly.
void dfs2(int x, int f){
    std::vector<std::pair<int, int>> stk(1, std::make_pair(x, f));   // (vertex, chain head)
    while(!stk.empty()){
        const int u = stk.back().first;
        const int head = stk.back().second;
        stk.pop_back();
        top[u] = head;
        dfn[u] = ++cnt;
        tp[cnt] = u;
        // Light children go in reverse adjacency order so they pop in adjacency
        // order; the heavy child goes last so it pops first and its chain gets
        // consecutive dfn values.
        for(auto it = g[u].rbegin(); it != g[u].rend(); ++it){
            const int y = *it;
            if(y == son[u] || y == fa[u]) continue;
            stk.push_back(std::make_pair(y, y));
        }
        if(son[u]) stk.push_back(std::make_pair(son[u], head));
    }
}

// Walks from x toward the root chain by chain, adding `delta` to the segment
// tree value of every vertex on the path, stopping at (and including) the
// deepest path vertex whose current value is <= limit. Returns that vertex, or
// 0 if no such vertex exists (then the whole path up to the root was updated).
// Relies on larger dfn meaning deeper within a heavy chain.
int climb(int x, int limit, int delta){
    while(x){
        const int l = dfn[top[x]], r = dfn[x];
        const int pos = seg::find(1, n, l, r, limit, 1);
        if(pos){
            seg::add(1, n, pos, r, delta, 1);
            return tp[pos];
        }
        seg::add(1, n, l, r, delta, 1);
        x = fa[top[x]];
    }
    return 0;
}

// Recolours x to 0 (the old colour is non-zero; presumably "W" = white —
// confirm). x's own weight goes +1 -> -1, so its value drops by 2. A child
// whose old value is >= 2 passes the full -2 to its parent. The first path
// vertex with value <= 1 (it can only be 1) ends at -1, so it becomes negative
// (++res) and passes just -1 upward. That -1 keeps propagating through vertices
// with value >= 1. The first vertex with value <= 0 (it can only be 0) becomes
// -1 (++res) and absorbs the rest. Assumes colours are 0/1.
void turnW(int x, int &res){
    const int u = climb(x, 1, -2);
    if(!u) return;
    ++res;
    if(climb(fa[u], 0, -1)) ++res;
}

// Recolours x from 0 to non-zero: x's value rises by 2. The +2 propagates
// through vertices whose old value is >= 0. The first vertex with value <= -1
// (it can only be -1) becomes +1, stops being negative (--res), and passes just
// +1 upward. That +1 propagates through vertices with value >= 0. The first
// vertex at -1 becomes 0 (--res) and absorbs the rest. Assumes colours are 0/1.
void turnB(int x, int &res){
    const int u = climb(x, -1, 2);
    if(!u) return;
    --res;
    if(climb(fa[u], -1, 1)) --res;
}
}
// Computes, for every vertex v in the subtree of x (parents come from HLD::fa,
// so HLD::dfs1 must already have run from the same root):
//   sz[v]  = (col[v] == 0 ? -1 : +1) + sum of sz[c] over children c with sz[c] > 0
//   ans[v] is increased by the number of vertices u in v's subtree with sz[u] < 0
// Uses an explicit stack instead of recursion so a path-shaped tree with
// ~2e5 vertices cannot overflow the call stack.
void dfs(int x){
    // Each frame: (vertex, index of the next neighbour of that vertex to examine).
    std::vector<std::pair<int, std::size_t>> stk;
    sz[x] = (col[x] == 0) ? -1 : 1;
    stk.emplace_back(x, 0);
    while(!stk.empty()){
        const int v = stk.back().first;
        if(stk.back().second < g[v].size()){
            const int y = g[v][stk.back().second++];
            if(y == HLD::fa[v]) continue;
            sz[y] = (col[y] == 0) ? -1 : 1;
            stk.emplace_back(y, 0);
            continue;
        }
        // Every child of v is finished, so sz[v] and ans[v] are final.
        if(sz[v] < 0) ++ans[v];
        stk.pop_back();
        if(!stk.empty()){
            const int p = stk.back().first;   // the vertex that descended into v, i.e. its parent
            if(sz[v] > 0) sz[p] += sz[v];
            ans[p] += ans[v];
        }
    }
}
int main(){
ios::sync_with_stdio(false); cin.tie(0); cout.tie(0);
int q;
cin >> n >> q;
rep(i, 1, n) cin >> col[i];
rep(i, 2, n){
int u, v;
cin >> u >> v;
g[u].push_back(v);
g[v].push_back(u);
}
HLD::dfs1(1);
HLD::dfs2(1, 1);
dfs(1);
seg::build(1, n, 1);
int res = ans[1];
while(q--){
int x, c;
cin >> x >> c;
if(col[x] == c){
cout << res << '\n';
continue;
}
col[x] = c;
if(c == 0){
HLD::turnW(x, res);
}else{
HLD::turnB(x, res);
}
cout << res << '\n';
}
return 0;
}