需要掌握线段树合并的基本知识。

显然只有同一深度的会合并,我们将每次合并后的线段树在头结点打一个递减标记、

每次合并需要将其标记下传,图片说明
最后在叶节点之间的合并中使用上这个函数,因为只有头结点统计答案,所以到头结点时标记全部下传即可。

#include <cstdio>
#include <algorithm>
#define re register
#define ll long long
#define int ll
#define rep(i,a,b) for(re int i=a;i<=b;++i)
#define per(i,a,b) for(re int i=a;i>=b;--i)
using namespace std;
template<typename T>
inline void read(T&x)
{
    x=0;
    char s=(char)getchar();
    bool f=false;
    while(!(s>='0'&&s<='9'))
    {
        if(s=='-')
            f=true;
        s=(char)getchar();
    }
    while(s>='0'&&s<='9')
    {
        x=(x<<1)+(x<<3)+s-'0';
        s=(char)getchar();
    }
    if(f)
        x=(~x)+1;
}
template<typename T,typename ...T1>
inline void read(T&x,T1&...x1)
{
    read(x);
    read(x1...);
}
const int N=2e5+5;
struct Edge
{
    int next,to;
} edge[N<<1];
int head[N],num_edge;
inline void add_edge(int from,int to)
{
    edge[++num_edge].next=head[from];
    edge[num_edge].to=to;
    head[from]=num_edge;
}
struct Tree
{
    int l,r;
    ll size,tag;
    inline void calc()
    {
        if(!size)
        {
            tag=0;
            return;
        }
        size=max(1ll,size-tag);
        tag=0;
    }
} tree[N*100];
int cnt;
#define lc(x) tree[x].l
#define rc(x) tree[x].r
int n,s;
inline void pushup(int x)
{
    tree[x].size=tree[lc(x)].size+tree[rc(x)].size;
}
inline void pushdown(int x)
{
    if(tree[x].tag)
    {
        if(lc(x))
            tree[lc(x)].tag+=tree[x].tag;
        if(rc(x))
            tree[rc(x)].tag+=tree[x].tag;
        tree[x].tag=0;
    }
}
inline int merge(int x,int y,int l=1,int r=n)
{
    if(!x||!y)
        return x|y;
    if(l==r)
    {
        tree[x].calc(),tree[y].calc();
        tree[x].size+=tree[y].size;
        return x;
    }
    int mid=(l+r)>>1;
    pushdown(x),pushdown(y);
    lc(x)=merge(lc(x),lc(y),l,mid);
    rc(x)=merge(rc(x),rc(y),mid+1,r);
    pushup(x);
    return x;
}
inline void update(int &rt,int l,int r,int pos,int val)
{
    if(!rt)
        rt=++cnt;
    if(l==r)
    {
        tree[rt].calc();
        tree[rt].size+=val;
        return;
    }
    int mid=(l+r)>>1;
    pushdown(rt);
    if(pos<=mid)
        update(lc(rt),l,mid,pos,val);
    else
        update(rc(rt),mid+1,r,pos,val);
    pushup(rt);
}
int dep[N],root[N],a[N];
inline void dfs(int u,int fa)
{
    dep[u]=dep[fa]+1;
    for(re int i=head[u]; i; i=edge[i].next)
    {
        int &v=edge[i].to;
        if(v==fa)
            continue;
        dfs(v,u);
        root[u]=merge(root[u],root[v]);
    }
    update(root[u],1,n,dep[u],a[u]);
    ++tree[root[u]].tag;
//    printf("%d %lld\n",u,tree[root[u]].size);
}
inline void build(int rt,int l,int r)
{
    if(!rt)
        return;
    if(l==r)
    {
        tree[rt].calc();
        return;
    }
    int mid=(l+r)>>1;
    pushdown(rt);
    build(lc(rt),l,mid);
    build(rc(rt),mid+1,r);
    pushup(rt);
}
signed main()
{
    read(n,s);
    for(re int i=1; i<=n; ++i)
        read(a[i]);
    for(re int i=1; i^n; ++i)
    {
        int u,v;
        read(u,v);
        add_edge(u,v);
        add_edge(v,u);
    }
//    printf("%d\n",num_edge);
    dfs(s,0);
    build(root[s],1,n);
    printf("%lld\n",tree[root[s]].size);
    return 0;
}