一颗 个点的以
为根的树
树上有一些松鼠
记点
上的松鼠数目为
进行若干次以下步骤直到所有都为
对于所有
令
对于所有
令
, (其中
表示
的儿子集合
下同)
求最后的
通过观察可以发现不同深度的点之间互不影响
所以我们可以对于每种深度的点分别考虑
建出虚树然后记
为
子树内的所有点的权值走到
时
的值是多少
转移比较容易想到
注意这里的
是虚树上的儿子集合,而
表示的是原树上的深度
复杂度
#include <bits/stdc++.h>
#define LL long long
using namespace std;
template <typename T> void read(T &x){
x = 0; int f = 1; char ch = getchar();
while (!isdigit(ch)) {if (ch == '-') f = -1; ch = getchar();}
while (isdigit(ch)) {x = x * 10 + ch - '0'; ch = getchar();}
x *= f;
}
inline void write(int x){if (x > 9) write(x/10); putchar(x%10+'0'); }
const int N = 400050;
int dpt[N],size[N],son[N],fa[N];
int To[N<<1],Ne[N<<1],He[N],_k = 0;
int n,rt;
inline void dfs1(int x){
size[x] = 1;
for (int p = He[x],y; p ; p = Ne[p]) if ((y=To[p])^fa[x]){
fa[y] = x,dpt[y] = dpt[x] + 1,dfs1(y),size[x] += size[y];
if (size[y] > size[son[x]]) son[x] = y;
}
}
int top[N],Time,po[N],id[N];
inline void dfs2(int x){
id[x] = ++Time; po[Time] = x;
if (son[x]){
top[son[x]] = top[x],dfs2(son[x]);
for (int y,p = He[x]; p ; p = Ne[p]) if (!top[y=To[p]]) top[y] = y,dfs2(y);
}
}
inline int LCA(int x,int y){
while (top[x] ^ top[y]) if (dpt[top[x]] < dpt[top[y]]) y = fa[top[y]]; else x = fa[top[x]];
return dpt[x] <= dpt[y] ? x : y;
}
vector<int>G[N];
int a[N];
vector<int>ch[N];
inline void add(int x,int y){
ch[x].push_back(x),ch[y].push_back(x);
}
struct Ti{
int st[N],top,col,tfa[N];
inline void addd(int &x){
if (!top){ st[++top] = x; return; }
static int p,z; p = LCA(st[top],x);
if (p == st[top]){ st[++top] = x; return; }
while (p != st[top] && top){
z = p;
if (top > 1 && LCA(z,st[top-1]) == st[top-1] && z != st[top-1]){
add(st[top],z); st[top] = z; st[++top] = x; return;
}
if (top == 1){
add(st[top],z); st[top] = z; st[++top] = x; return;
}
add(st[top],st[top-1]); --top; p = LCA(st[top],x);
}
st[++top] = x;
}
int pp[N],lenp;
inline bool cmpd(int x,int y){ return id[x] < id[y]; }
inline void Getit(vector<int> &v){
static int i;
top = 0;
lenp = v.size(); for (i = 1; i <= lenp; ++i) pp[i] = id[v[i-1]];
sort(pp+1,pp+lenp+1);
for (i = 1; i <= lenp; ++i) pp[i] = po[pp[i]];
for (i = 1; i <= lenp; ++i)
addd(pp[i]);
while (top > 1) add(st[top],st[top-1]),--top;
}
}TT;
LL ans;
int nowC;
int p[N],cntp; bool vis[N];
inline LL dfs(int x){
if (vis[x]) return 0; vis[x] = 1; p[++cntp] = x;
LL now,v,d;
now = 0; if (dpt[x] == nowC) now = a[x];
for (int i = 0,y; i < ch[x].size(); ++i){
y = ch[x][i]; if (vis[y]) continue;
v = dfs(y),d = dpt[y] - dpt[x];
if (!v) now += v;
else if (v==1) ++now;
else now += max(1ll,v-d);
}
return now;
}
int main(){
int i,x,y; LL vv;
read(n),read(rt);
for (i = 1; i <= n; ++i) read(a[i]);
for (i = 1; i < n; ++i){
read(x),read(y);
++_k; To[_k] = y,Ne[_k] = He[x],He[x] = _k;
++_k; To[_k] = x,Ne[_k] = He[y],He[y] = _k;
// cerr<<"EDGE "<<x<<' '<<y<<'\n';
}
vv = a[rt],a[rt] = 0;
ans = 0;
if (!vv) ans=0; else if (vv==1)++ans; else ans+=vv-1;
dpt[rt] = 1; top[rt] = rt;
dfs1(rt); dfs2(rt);
for (i = 1; i <= n; ++i) if (i != rt) G[dpt[i]].push_back(i);
for (i = 0; i <= n; ++i) if (G[i].size()){
G[i].push_back(rt); TT.Getit(G[i]);
nowC = i;
cntp = 0; fa[rt] = -1,vis[rt] = 0;
vv = dfs(rt);
if (!vv) ans+=0;else if (vv>1) ans += vv-1; else ++ans;
while (cntp){
x = p[cntp],--cntp;
ch[x].clear(),vis[x] = 0;
}
}
cout << ans << '\n';
return 0;
} 
京公网安备 11010502036488号