需要掌握线段树合并的基本知识。
显然只有同一深度的会合并,我们将每次合并后的线段树在头结点打一个递减标记、
每次合并需要将其标记下传,
最后在叶节点之间的合并中使用上这个函数,因为只有头结点统计答案,所以到头结点时标记全部下传即可。
#include <cstdio> #include <algorithm> #define re register #define ll long long #define int ll #define rep(i,a,b) for(re int i=a;i<=b;++i) #define per(i,a,b) for(re int i=a;i>=b;--i) using namespace std; template<typename T> inline void read(T&x) { x=0; char s=(char)getchar(); bool f=false; while(!(s>='0'&&s<='9')) { if(s=='-') f=true; s=(char)getchar(); } while(s>='0'&&s<='9') { x=(x<<1)+(x<<3)+s-'0'; s=(char)getchar(); } if(f) x=(~x)+1; } template<typename T,typename ...T1> inline void read(T&x,T1&...x1) { read(x); read(x1...); } const int N=2e5+5; struct Edge { int next,to; } edge[N<<1]; int head[N],num_edge; inline void add_edge(int from,int to) { edge[++num_edge].next=head[from]; edge[num_edge].to=to; head[from]=num_edge; } struct Tree { int l,r; ll size,tag; inline void calc() { if(!size) { tag=0; return; } size=max(1ll,size-tag); tag=0; } } tree[N*100]; int cnt; #define lc(x) tree[x].l #define rc(x) tree[x].r int n,s; inline void pushup(int x) { tree[x].size=tree[lc(x)].size+tree[rc(x)].size; } inline void pushdown(int x) { if(tree[x].tag) { if(lc(x)) tree[lc(x)].tag+=tree[x].tag; if(rc(x)) tree[rc(x)].tag+=tree[x].tag; tree[x].tag=0; } } inline int merge(int x,int y,int l=1,int r=n) { if(!x||!y) return x|y; if(l==r) { tree[x].calc(),tree[y].calc(); tree[x].size+=tree[y].size; return x; } int mid=(l+r)>>1; pushdown(x),pushdown(y); lc(x)=merge(lc(x),lc(y),l,mid); rc(x)=merge(rc(x),rc(y),mid+1,r); pushup(x); return x; } inline void update(int &rt,int l,int r,int pos,int val) { if(!rt) rt=++cnt; if(l==r) { tree[rt].calc(); tree[rt].size+=val; return; } int mid=(l+r)>>1; pushdown(rt); if(pos<=mid) update(lc(rt),l,mid,pos,val); else update(rc(rt),mid+1,r,pos,val); pushup(rt); } int dep[N],root[N],a[N]; inline void dfs(int u,int fa) { dep[u]=dep[fa]+1; for(re int i=head[u]; i; i=edge[i].next) { int &v=edge[i].to; if(v==fa) continue; dfs(v,u); root[u]=merge(root[u],root[v]); } update(root[u],1,n,dep[u],a[u]); ++tree[root[u]].tag; // printf("%d %lld\n",u,tree[root[u]].size); } inline void build(int rt,int l,int r) { if(!rt) return; if(l==r) { tree[rt].calc(); return; } int mid=(l+r)>>1; pushdown(rt); build(lc(rt),l,mid); build(rc(rt),mid+1,r); pushup(rt); } signed main() { read(n,s); for(re int i=1; i<=n; ++i) read(a[i]); for(re int i=1; i^n; ++i) { int u,v; read(u,v); add_edge(u,v); add_edge(v,u); } // printf("%d\n",num_edge); dfs(s,0); build(root[s],1,n); printf("%lld\n",tree[root[s]].size); return 0; }