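// This one was really hard to write... I could not come up with the idea myself, and I cannot prove it either.
// Copied from https://ac.nowcoder.com/acm/problem/blogs/200179
// Really impressive (sob).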
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <iterator>
#include <set>
#include <string>
#include <utility>
#include <vector>

// Maintains, for every colour c, ans[c] = number of edges of the smallest
// connected subgraph of the tree that contains all vertices of colour c
// (-1 when no vertex has that colour), under single-vertex recolouring.
//
// Technique: keep the vertices of each colour ordered by DFS index. When a
// vertex x is inserted between its cyclic DFS-order neighbours y and z, the
// subgraph grows by the distance from x to the y-z path, which is
// (dis(x,y) + dis(x,z) - dis(y,z)) / 2. Removal subtracts the same amount.
//
// Input format, as read by main(): n, then n-1 edges "l r", then n colours,
// then m, then m operations, each either "U u c" (recolour vertex u to c)
// or any other token followed by "c" (print ans[c]).

constexpr int maxn = 100000 + 7;  // capacity of every array: vertex ids and colours must be < maxn
constexpr int LOG = 20;           // binary-lifting levels; 2^20 exceeds maxn

int n;                  // number of vertices
int a[maxn];            // a[v]   = current colour of vertex v
int ans[maxn];          // ans[c] = current answer for colour c, -1 if colour c is unused
std::vector<int> g[maxn];  // adjacency lists
std::set<int> s[maxn];     // s[c] = DFS indices of the vertices that currently have colour c
int deep[maxn];         // deep[v] = depth of v, root has depth 1, sentinel vertex 0 has depth 0
int fa[maxn][LOG];      // fa[v][k] = 2^k-th ancestor of v, 0 when it does not exist
int dfn[maxn];          // dfn[t] = vertex whose DFS index is t
int tag;                // last DFS index handed out
int in[maxn];           // in[v]  = DFS index of vertex v (0 = not visited yet)

// Preorder traversal from `u` (whose parent is `father`) that fills dfn, in,
// deep and fa. Iterative on purpose: a path-shaped tree with ~1e5 vertices
// would otherwise need ~1e5 nested calls. Neighbours are visited in
// adjacency-list order, so the numbering matches a plain recursive preorder.
void dfs(int u, int father) {
    std::vector<std::pair<int, std::size_t>> stack;  // (vertex, index of next neighbour to try)

    auto enter = [&stack](int v, int parent) {
        dfn[++tag] = v;
        in[v] = tag;
        deep[v] = deep[parent] + 1;
        fa[v][0] = parent;
        for (int k = 1; k < LOG; ++k) {
            fa[v][k] = fa[fa[v][k - 1]][k - 1];
        }
        stack.emplace_back(v, 0);
    };

    enter(u, father);
    while (!stack.empty()) {
        const int cur = stack.back().first;
        const std::size_t k = stack.back().second;
        if (k == g[cur].size()) {
            stack.pop_back();
            continue;
        }
        ++stack.back().second;  // advance before enter() may reallocate the stack
        const int next = g[cur][k];
        // In a tree the only already-visited neighbour is the parent; testing
        // "visited" instead of "is parent" also keeps the loop finite on
        // malformed (non-tree) input.
        if (in[next] == 0) {
            enter(next, cur);
        }
    }
}

// Lowest common ancestor of vertices x and y via binary lifting.
int lca(int x, int y) {
    if (deep[x] < deep[y]) {
        std::swap(x, y);
    }
    // Lift the deeper vertex to the depth of the other one.
    for (int k = LOG - 1; k >= 0; --k) {
        if (deep[fa[x][k]] >= deep[y]) {
            x = fa[x][k];
        }
    }
    if (x == y) {
        return x;
    }
    // Lift both as long as their ancestors differ; they end just below the LCA.
    for (int k = LOG - 1; k >= 0; --k) {
        if (fa[x][k] != fa[y][k]) {
            x = fa[x][k];
            y = fa[y][k];
        }
    }
    return fa[x][0];
}

// Number of edges between two vertices given by their DFS indices u and v
// (NOT by vertex id): both are mapped back to vertices through dfn first.
int dis(int u, int v) {
    u = dfn[u];
    v = dfn[v];
    return deep[u] + deep[v] - 2 * deep[lca(u, v)];
}

// Edges gained by attaching DFS index x to `group` (equivalently, edges lost
// by detaching it). `group` must be non-empty and must not contain x.
// The two neighbours are x's predecessor and successor in DFS order; when x
// would be the smallest or the largest element the order wraps around, so
// the neighbours are the smallest and the largest element of the group.
static int attach_cost(const std::set<int>& group, int x) {
    const auto it = group.lower_bound(x);
    int left;
    int right;
    if (it == group.begin() || it == group.end()) {
        left = *group.begin();
        right = *group.rbegin();
    } else {
        left = *std::prev(it);
        right = *it;
    }
    return (dis(x, left) + dis(x, right) - dis(left, right)) / 2;
}

// Gives colour c to the vertex with DFS index x and updates ans[c].
void add(int x, int c) {
    if (s[c].empty()) {
        s[c].insert(x);
        ans[c] = 0;  // a single vertex spans zero edges
        return;
    }
    ans[c] += attach_cost(s[c], x);
    s[c].insert(x);
}

// Takes colour c away from the vertex with DFS index x and updates ans[c].
// Does nothing when x is not currently in colour c.
void del(int x, int c) {
    if (s[c].erase(x) == 0) {
        return;
    }
    if (s[c].empty()) {
        ans[c] = -1;  // colour no longer used
        return;
    }
    ans[c] -= attach_cost(s[c], x);
}

static bool is_vertex(int v) { return 1 <= v && v <= n; }
static bool is_colour(int c) { return 0 <= c && c < maxn; }

// Reads the tree, the initial colouring and the operations from stdin and
// prints one line per query. Returns 1 on malformed or out-of-range input.
// C stdio and iostream are mixed below; this relies on the default
// sync_with_stdio(true), so do not disable it.
int main() {
    if (std::scanf("%d", &n) != 1 || n < 1 || n >= maxn) {
        return 1;
    }
    for (int i = 1; i < n; ++i) {
        int l;
        int r;
        if (std::scanf("%d%d", &l, &r) != 2 || !is_vertex(l) || !is_vertex(r)) {
            return 1;
        }
        g[l].push_back(r);
        g[r].push_back(l);
    }

    dfs(1, 0);
    if (tag != n) {
        return 1;  // not every vertex is reachable from vertex 1: input is not a tree
    }

    std::fill(std::begin(ans), std::end(ans), -1);
    for (int i = 1; i <= n; ++i) {
        if (std::scanf("%d", &a[i]) != 1 || !is_colour(a[i])) {
            return 1;
        }
        add(in[i], a[i]);
    }

    int m;
    if (std::scanf("%d", &m) != 1) {
        return 1;
    }
    std::string op;
    while (m-- > 0) {
        if (!(std::cin >> op)) {
            return 1;
        }
        if (op[0] == 'U') {
            int u;
            int c;
            if (std::scanf("%d%d", &u, &c) != 2 || !is_vertex(u) || !is_colour(c)) {
                return 1;
            }
            del(in[u], a[u]);
            a[u] = c;
            add(in[u], a[u]);
        } else {
            int c;
            if (std::scanf("%d", &c) != 1) {
                return 1;
            }
            // A colour outside the table cannot belong to any vertex.
            std::printf("%d\n", is_colour(c) ? ans[c] : -1);
        }
    }
    return 0;
}