虚树入门
前置知识
掌握求lca,dfs序,每个点的深度,两点之间的距离等操作
什么是虚树
虚树的核心思想,就是保留一棵树上有用的&关键的点,重新构建一棵“虚”树,更加快速地跑带关键点的树形dp。
怎么用虚树
这里就需要用到上面提到的求lca(最近公共祖先)技术。
我们对一棵虚树的要求是,节点总数最小(没有多余的累赘),包含题目指定的节点,以及它们的lca。稍加思索即可明白,在这种情况下节点会被最小化。
(因为比较易懂,就不画图了懒得画)
首先我们要在树上跑dfs,求出前缀的dfs序,再以dfs序来从小到大排序;同时我们维护一个栈,表示从树根到栈顶元素的这条链。
得到关键点的时,按照dfn值(dfs序)排序,从前往后扫描,应用栈中的信息,先将这颗虚树build()好。
具体如何操作呢?
设我们的栈为s,关键点为x,栈顶指针为y。
- 如果栈为空或栈中只有一个元素,就应该s[++y]=x,将关键点压在栈顶;
- 我们找出x与s[y]的lca,如果lca=s[y],说明x点接着s[y]点,延长树链。所以我们还是做s[++y]=x;
- 反之,如果lca≠s[y],说明x与s[y]分属lca的不同子树,而且虚树上包含s[y]的这颗子树已经完成:
因此,我们需要将lca包含s[y]子树的那部分退栈,并将这部分建边形成虚树。
如果lca不在栈中,那么就把lca压入栈。接着延长树链,将x加入栈中。
在每个点做一次,我们的虚树就构建好了,求链上最小值,接着就是树形dp的操作。
模板核心代码
void find(int x,int y,int z)
{
int o,i;
a[x]=++t,c[x]=z,f[x][0]=y;
for(i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1];
for(i=0;i<l1[x].size();i++)
{
o=l1[x][i];
if(o!=y) find(o,x,z+1);
}
b[x]=t;
}
int lca(int x,int y)
{
if(c[x]<c[y]) swap(x,y);
int z=c[x]-c[y],i;
for(i=0;i<20 && z;i++)
if(z&(1<<i)) x=f[x][i];
if(x==y) return x;
for(i=19;i>=0;i--)
if(f[x][i]!=f[y][i])
x=f[x][i],y=f[y][i];
return f[x][0];
}
void build()
{
sort(d+1,d+m+1,cmp);
int top=0,num=m,i;
for(i=1;i<num;i++)
d[++m]=lca(d[i],d[i+1]);
sort(d+1,d+m+1,cmp);
m=unique(d+1,d+m+1)-d-1;
s[++top]=d[1];
for(i=2;i<=m;i++)
{
while(top && b[s[top]]<a[d[i]]) top--;
if(top) l2[s[top]].push_back(d[i]);
s[++top]=d[i];
}
}
void dfs(int x)
{
int o,i;
ans[x]=0;
if(v[x])
{
k[x]=1;
for(i=0;i<l2[x].size();i++)
{
o=l2[x][i];
dfs(o);
ans[x]+=ans[o];
if(k[o]) ans[x]++;
}
}
else
{
k[x]=0;
for(i=0;i<l2[x].size();i++)
{
o=l2[x][i];
dfs(o);
ans[x]+=ans[o],k[x]+=k[o];
}
if(k[x]>1) k[x]=0,ans[x]++;
}
}例题推荐
1. Codeforces 613D - Kindom and its Cities
Tag: 模板题
AC代码
#include<bits/stdc++.h>
using namespace std;
const int M=200010;
int a[M],b[M],c[M],d[M],f[M][20],s[M],v[M],ans[M],k[M],n,q,m,t;
vector<int> l1[M],l2[M];
//上方模板
int main()
{
int x,y,i;
bool flag;
scanf("%d",&n);
for(i=1;i<=n;i++)
l1[i].clear();
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
l1[x].push_back(y);
l1[y].push_back(x);
}
find(1,0,0);
scanf("%d",&q);
while(q--)
{
flag=1;
scanf("%d",&m);
for(i=1;i<=m;i++)
{
scanf("%d",&d[i]);
v[d[i]]=1;
}
for(i=1;i<=m;i++)
if(v[f[d[i]][0]])
{
flag=0;
break;
}
if(!flag) puts("-1");
else
{
build();
dfs(d[1]);
printf("%d\n",ans[d[1]]);
}
for(i=1;i<=m;i++)
{
v[d[i]]=0;
l2[d[i]].clear();
}
}
}2. 洛谷P2495 [SDOI2011] - 消耗战
Tag: 模板题
3. HDU2196 - Computer
Tag: 虚树+换根+树上点到其他点的最长距离
4. POJ 3585 - Accumulation Degree
Tag: 虚树+换根+树上最大流
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int M=500010;
struct link
{
int x,y,z;
} a[M*2];
int b[M],c[M],dp[M],f[M],r,s,ans;
bool v[M];
void clear()
{
s=ans=0,r=1;
memset(v,0,sizeof(v));
memset(a,0,sizeof(a));
memset(dp,0,sizeof(dp));
memset(c,0,sizeof(c));
memset(b,0,sizeof(b));
memset(f,0,sizeof(f));
}
void add(int x,int y,int z)
{
a[++s].x=y,a[s].y=b[x],b[x]=s,a[s].z=z;
}
void dfs(int x)
{
int y,i;
v[x]=1;
for(i=b[x];i;i=a[i].y)
{
y=a[i].x;
if(!v[y])
{
dfs(y);
if(c[y]!=1) dp[x]+=min(a[i].z,dp[y]);
else dp[x]+=a[i].z;
}
}
}
void ddfs(int x)
{
int y,i;
v[x]=1;
for(i=b[x];i;i=a[i].y)
{
y=a[i].x;
if(!v[y])
{
if(c[x]!=1) f[y]=dp[y]+min(f[x]-min(dp[y],a[i].z),a[i].z);
else f[y]=dp[y]+a[i].z;
ans=max(ans,f[y]);
ddfs(y);
}
}
}
int main()
{
int t,x,y,z,n,i;
scanf("%d",&t);
while(t--)
{
clear();
scanf("%d",&n);
for(i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
c[x]++,c[y]++;
add(x,y,z);
add(y,x,z);
}
dfs(r);
memset(v,0,sizeof(v));
f[r]=dp[r];
ddfs(r);
printf("%d\n",ans);
}
return 0;
}5. 洛谷P4103 [HEOI2014] - 大工程
Tag: 虚树+树上最短/最长路径
AC代码
#include<bits/stdc++.h>
using namespace std;
const int N=1000010,inf=0x7f7f7f7f7f;
struct link
{
int a,b;
} l[2*N];
int a[N],p[N],q[N],f[22][N],t[N*4],n,m,tot,w;
long long c[N],g[N],mi[N],ma[N],d[N],ans1,ans2,ans3;
bool b[N],v[N];
stack<int> s;
void add(int x,int y)
{
l[++tot].a=y,l[tot].b=a[x],a[x]=tot;
}
void dfs(int x)
{
int y=a[x],z,i;
p[x]=++w,b[x]=1;
for(i=0;f[i][x];i++)
f[i+1][x]=f[i][f[i][x]];
while(y)
{
z=l[y].a;
if(!b[z])
{
d[z]=d[x]+1,f[0][z]=x;
dfs(z);
}
y=l[y].b;
}
q[x]=++w;
}
int lca(int x,int y)
{
if(d[x]<d[y]) swap(x,y);
int z=d[x]-d[y],i;
for(i=0;z;z>>=1,i++)
if(z%2) x=f[i][x];
if(x==y) return x;
for(i=21;i>=0;i--)
if(f[i][x]!=f[i][y])
x=f[i][x],y=f[i][y];
return f[0][x];
}
bool cmp(int x,int y)
{
int xx,yy;
if(x>0) xx=p[x];
else xx=q[-x];
if(y>0) yy=p[y];
else yy=q[-y];
return xx<yy;
}
int main()
{
int tot,x,y,z,i,j,k;
long long ss;
scanf("%d",&n);
for(i=1;i<=n;i++) mi[i]=inf;
for(i=1;i<n;i++)
{
scanf("%d%d",&x,&y);
add(x,y);
add(y,x);
}
dfs(1);
scanf("%d",&m);
for(i=1;i<=m;i++)
{
scanf("%d",&k);
ans1=ans3=0,ans2=inf,tot=k;
for(j=1;j<=k;j++)
{
scanf("%d",&t[j]);
mi[t[j]]=0,g[t[j]]=1,v[t[j]]=1;
}
sort(t+1,t+k+1,cmp);
t[++tot]=-t[1];
for(j=2;j<=k;j++)
{
z=lca(t[j],t[j-1]);
t[++tot]=-t[j];
if(!v[z])
t[++tot]=z,t[++tot]=-z,v[z]=1;
}
sort(t+1,t+tot+1,cmp);
for(j=1;j<=tot;j++)
{
if(t[j]>0)
{
s.push(t[j]);
continue;
}
if(t[j]<0)
{
x=s.top();
s.pop();
if(!s.empty())
{
y=s.top(),ss=(d[x]-d[y]);
c[x]+=g[x]*ss,ans1+=g[y]*c[x]+g[x]*c[y];
g[y]+=g[x],c[y]+=c[x];
mi[x]+=ss,ans2=min(ans2,mi[y]+mi[x]),mi[y]=min(mi[y],mi[x]);
ma[x]+=ss,ans3=max(ans3,ma[y]+ma[x]),ma[y]=max(ma[y],ma[x]);
}
g[x]=c[x]=ma[x]=0,mi[x]=inf,v[x]=0;
}
}
printf("%lld %lld %lld\n",ans1,ans2,ans3);
}
}
京公网安备 11010502036488号