一棵 $n$ 个点、以 $rt$ 为根的树，树上有一些松鼠，记点 $i$ 上的松鼠数目为 $a_i$。

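进行若干次以下步骤，直到所有 $a_i$ 都为 $0$（原文公式在提取时丢失，以下按后文代码的实际行为还原）：

1. 对于所有 $i$，若 $a_i > 1$，则 $a_i \leftarrow a_i - 1$；

2. 先令 $ans \leftarrow ans + a_{rt}$，再对于所有 $i$，$a_i \leftarrow \sum_{j \in son_i} a_j$（其中 $son_i$ 表示 $i$ 的儿子集合，下同）。

求最后的 $ans$。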

通过观察可以发现不同深度的点之间互不影响

所以我们可以对于每种深度的点分别考虑

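对每种深度 $d$，把该深度的所有点连同根 $rt$ 一起建出虚树，然后记 $f_x$ 为 $x$ 子树内所有深度为 $d$ 的点的权值走到 $x$ 时的值。

转移比较容易想到：$f_x = [dep_x = d]\,a_x + \sum_{y \in son_x} g(f_y,\ dep_y - dep_x)$，其中当 $v \le 1$ 时 $g(v, l) = v$，否则 $g(v, l) = \max(1,\ v - l)$；该深度对答案的贡献为 $f_{rt} - [f_{rt} > 1]$。注意这里的 $son_x$ 是虚树上的儿子集合，而 $dep$ 表示的是原树上的深度。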

复杂度 $O(n \log n)$（每层按 dfs 序排序，并用树链剖分求 LCA）。

#include <bits/stdc++.h>
#define LL long long
using namespace std;
// Reads one decimal integer from stdin into x.
//
// Skips every character up to the first digit; if any '-' is seen while
// skipping, the parsed value is negated. Parsing stops at the first
// non-digit after the number (that character is consumed).
//
// If end of input is reached before any digit, x is left at 0 and the
// function returns instead of looping forever.
template <typename T> void read(T &x){
    x = 0;
    bool negative = false;
    // Keep the character as int: getchar() returns int so that EOF stays
    // distinguishable from real bytes, and isdigit() is only defined for
    // unsigned-char values and EOF.
    int c = std::getchar();
    while (c != EOF && !std::isdigit(c)){
        if (c == '-') negative = true;
        c = std::getchar();
    }
    while (c != EOF && std::isdigit(c)){
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    if (negative) x = -x;
}
// Prints x in decimal to stdout, one character at a time via putchar.
// Negative values are printed with a leading '-'.
inline void write(int x){
    long long v = x;                     // widened so that negating INT_MIN cannot overflow
    if (v < 0) { std::putchar('-'); v = -v; }
    if (v > 9) write(static_cast<int>(v / 10));   // higher-order digits first
    std::putchar(static_cast<int>(v % 10) + '0');
}

// Capacity of every per-node array (node ids are used as indices).
const int N = 400050;
// Filled by dfs1: depth, subtree node count, child with the largest subtree
// (0 when there is none) and parent of each node.
int dpt[N],size[N],son[N],fa[N];

// Adjacency lists stored as linked edge records: He[x] is the first edge of x,
// Ne[e] the next edge after e, To[e] the endpoint of e; _k is the edge counter.
int To[N<<1],Ne[N<<1],He[N],_k = 0;
int n,rt;
// First tree walk from x: records parent, depth, subtree size and the child
// with the largest subtree for every node below x.
//
// The caller sets dpt[x] for the starting node beforehand; fa[x] of the
// starting node must not equal any neighbour id (presumably ids are 1-based
// and fa[] is still zero-initialised — confirm against the caller).
//
// The array is spelled `::size` on purpose: the file has `using namespace std;`
// and since C++17 an unqualified `size` is ambiguous with std::size, which
// makes the original spelling fail to compile. The qualified name always
// refers to the global array.
inline void dfs1(int x){
    ::size[x] = 1;
    for (int e = He[x]; e; e = Ne[e]){
        const int y = To[e];
        if (y == fa[x]) continue;        // do not walk back to the parent
        fa[y] = x;
        dpt[y] = dpt[x] + 1;
        dfs1(y);
        ::size[x] += ::size[y];
        // Strict '>' keeps the first child encountered on ties.
        if (::size[y] > ::size[son[x]]) son[x] = y;
    }
}
int top[N],Time,po[N],id[N]; 
inline void dfs2(int x){
    id[x] = ++Time; po[Time] = x;
    if (son[x]){
        top[son[x]] = top[x],dfs2(son[x]);
        for (int y,p = He[x]; p ; p = Ne[p]) if (!top[y=To[p]]) top[y] = y,dfs2(y);
    }
}
inline int LCA(int x,int y){
    while (top[x] ^ top[y]) if (dpt[top[x]] < dpt[top[y]]) y = fa[top[y]]; else x = fa[top[x]];
    return dpt[x] <= dpt[y] ? x : y;
}

// G[d]: nodes whose depth is d (filled in main).
vector<int>G[N];

// a[i]: value read for node i.
int a[N];
// ch[x]: neighbours of x in the tree currently being assembled by Ti.
vector<int>ch[N];
// Records an undirected edge between x and y in ch[].
//
// The original pushed x into its own list (a self-loop) instead of y. That
// only went unnoticed because dfs() skips already-visited neighbours. Each
// endpoint now stores the other one, and that same visited check keeps the
// traversal from walking back along the edge.
inline void add(int x,int y){
    ch[x].push_back(y);
    ch[y].push_back(x);
}
// Assembles, for a given set of nodes, the tree formed by those nodes and the
// LCAs needed to connect them (the "virtual tree" of the write-up above).
// Edges are emitted through add(); st[1..top] is the stack of nodes on the
// current rightmost path.
struct Ti{
    int st[N],top,col,tfa[N];
    // Inserts node x. Nodes must arrive in increasing id[] order (Getit sorts them).
    inline void addd(int &x){
        if (top == 0){ st[++top] = x; return; }
        int anc = LCA(st[top], x);
        // Unwind the stack until its top is an ancestor of x.
        while (top && anc != st[top]){
            // The LCA lies strictly between st[top-1] and st[top] ...
            const bool splitsEdge =
                top > 1 && anc != st[top-1] && LCA(anc, st[top-1]) == st[top-1];
            // ... or the stack holds a single node: either way the LCA
            // replaces the stack top and becomes its parent.
            if (splitsEdge || top == 1){
                add(st[top], anc);
                st[top] = anc;
                break;
            }
            // Otherwise the top is finished: link it to the node below and pop.
            add(st[top], st[top-1]);
            --top;
            anc = LCA(st[top], x);
        }
        st[++top] = x;
    }
    int pp[N],lenp;
    inline bool cmpd(int x,int y){ return id[y] > id[x]; }
    // Builds the tree for the nodes in v: orders them by id[], feeds them to
    // addd one by one, then links whatever is left on the stack.
    inline void Getit(vector<int> &v){
        top = 0;
        lenp = v.size();
        for (int k = 0; k < lenp; ++k) pp[k+1] = id[v[k]];
        sort(pp+1, pp+lenp+1);
        for (int k = 1; k <= lenp; ++k){
            pp[k] = po[pp[k]];           // map the sorted id back to its node
            addd(pp[k]);
        }
        while (top > 1){
            add(st[top], st[top-1]);
            --top;
        }
    }
}TT;
// Running total printed at the end of main.
LL ans;
// Depth currently being processed.
int nowC;
// p[1..cntp]: nodes touched by the current traversal (used by main to reset
// ch[] and vis[]); vis[x]: x has already been entered.
int p[N],cntp; bool vis[N];
// Walks the tree stored in ch[] from x and returns the amount that reaches x.
// A node whose depth equals nowC starts with a[x]. Each unvisited neighbour y
// contributes its own result v unchanged when v is 0 or 1, and otherwise
// max(1, v - (dpt[y] - dpt[x])).
// Returns 0 right away when x was already visited.
inline LL dfs(int x){
    if (vis[x]) return 0;
    vis[x] = 1;
    p[++cntp] = x;
    LL total = (dpt[x] == nowC) ? a[x] : 0;
    for (int y : ch[x]){
        if (vis[y]) continue;
        const LL arrived = dfs(y);
        const LL dist = dpt[y] - dpt[x];
        if (arrived == 0 || arrived == 1) total += arrived;
        else total += max(1ll, arrived - dist);
    }
    return total;
}
int main(){
    int i,x,y; LL vv;
    read(n),read(rt);
    for (i = 1; i <= n; ++i) read(a[i]);
    for (i = 1; i < n; ++i){
        read(x),read(y);
        ++_k; To[_k] = y,Ne[_k] = He[x],He[x] = _k;
        ++_k; To[_k] = x,Ne[_k] = He[y],He[y] = _k;
    //    cerr<<"EDGE "<<x<<' '<<y<<'\n';
    }
    vv = a[rt],a[rt] = 0;
    ans = 0;
    if (!vv) ans=0; else if (vv==1)++ans; else ans+=vv-1;
    dpt[rt] = 1; top[rt] = rt; 
    dfs1(rt); dfs2(rt);
    for (i = 1; i <= n; ++i) if (i != rt) G[dpt[i]].push_back(i);
    for (i = 0; i <= n; ++i) if (G[i].size()){
        G[i].push_back(rt); TT.Getit(G[i]);
        nowC = i;
        cntp = 0; fa[rt] = -1,vis[rt] = 0;
        vv = dfs(rt);
        if (!vv) ans+=0;else if (vv>1) ans += vv-1; else ++ans;
        while (cntp){
            x = p[cntp],--cntp;
            ch[x].clear(),vis[x] = 0;
        }
    }
    cout << ans << '\n';
    return 0;
}