虚树入门
前置知识
掌握求lca,dfs序,每个点的深度,两点之间的距离等操作
什么是虚树
虚树的核心思想,就是保留一棵树上有用的&关键的点,重新构建一棵“虚”树,更加快速地跑带关键点的树形dp。
怎么用虚树
这里就需要用到上面提到的求lca(最近公共祖先)技术。
我们对一棵虚树的要求是,节点总数最小(没有多余的累赘),包含题目指定的节点,以及它们的lca。稍加思索即可明白,在这种情况下节点会被最小化。
(因为比较易懂,就不画图了懒得画)
首先我们要在树上跑dfs,求出前缀的dfs序,再以dfs序来从小到大排序;同时我们维护一个栈,表示从树根到栈顶元素的这条链。
得到关键点的时,按照dfn值(dfs序)排序,从前往后扫描,应用栈中的信息,先将这颗虚树build()好。
具体如何操作呢?
设我们的栈为s,关键点为x,栈顶指针为y。
- 如果栈为空或栈中只有一个元素,就应该s[++y]=x,将关键点压在栈顶;
- 我们找出x与s[y]的lca,如果lca=s[y],说明x点接着s[y]点,延长树链。所以我们还是做s[++y]=x;
- 反之,如果lca≠s[y],说明x与s[y]分属lca的不同子树,而且虚树上包含s[y]的这颗子树已经完成:
因此,我们需要将lca包含s[y]子树的那部分退栈,并将这部分建边形成虚树。
如果lca不在栈中,那么就把lca压入栈。接着延长树链,将x加入栈中。
在每个点做一次,我们的虚树就构建好了,求链上最小值,接着就是树形dp的操作。
模板核心代码
void find(int x,int y,int z) { int o,i; a[x]=++t,c[x]=z,f[x][0]=y; for(i=1;i<20;i++) f[x][i]=f[f[x][i-1]][i-1]; for(i=0;i<l1[x].size();i++) { o=l1[x][i]; if(o!=y) find(o,x,z+1); } b[x]=t; } int lca(int x,int y) { if(c[x]<c[y]) swap(x,y); int z=c[x]-c[y],i; for(i=0;i<20 && z;i++) if(z&(1<<i)) x=f[x][i]; if(x==y) return x; for(i=19;i>=0;i--) if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i]; return f[x][0]; } void build() { sort(d+1,d+m+1,cmp); int top=0,num=m,i; for(i=1;i<num;i++) d[++m]=lca(d[i],d[i+1]); sort(d+1,d+m+1,cmp); m=unique(d+1,d+m+1)-d-1; s[++top]=d[1]; for(i=2;i<=m;i++) { while(top && b[s[top]]<a[d[i]]) top--; if(top) l2[s[top]].push_back(d[i]); s[++top]=d[i]; } } void dfs(int x) { int o,i; ans[x]=0; if(v[x]) { k[x]=1; for(i=0;i<l2[x].size();i++) { o=l2[x][i]; dfs(o); ans[x]+=ans[o]; if(k[o]) ans[x]++; } } else { k[x]=0; for(i=0;i<l2[x].size();i++) { o=l2[x][i]; dfs(o); ans[x]+=ans[o],k[x]+=k[o]; } if(k[x]>1) k[x]=0,ans[x]++; } }
例题推荐
1. Codeforces 613D - Kindom and its Cities
Tag: 模板题
AC代码
#include<bits/stdc++.h> using namespace std; const int M=200010; int a[M],b[M],c[M],d[M],f[M][20],s[M],v[M],ans[M],k[M],n,q,m,t; vector<int> l1[M],l2[M]; //上方模板 int main() { int x,y,i; bool flag; scanf("%d",&n); for(i=1;i<=n;i++) l1[i].clear(); for(i=1;i<n;i++) { scanf("%d%d",&x,&y); l1[x].push_back(y); l1[y].push_back(x); } find(1,0,0); scanf("%d",&q); while(q--) { flag=1; scanf("%d",&m); for(i=1;i<=m;i++) { scanf("%d",&d[i]); v[d[i]]=1; } for(i=1;i<=m;i++) if(v[f[d[i]][0]]) { flag=0; break; } if(!flag) puts("-1"); else { build(); dfs(d[1]); printf("%d\n",ans[d[1]]); } for(i=1;i<=m;i++) { v[d[i]]=0; l2[d[i]].clear(); } } }
2. 洛谷P2495 [SDOI2011] - 消耗战
Tag: 模板题
3. HDU2196 - Computer
Tag: 虚树+换根+树上点到其他点的最长距离
4. POJ 3585 - Accumulation Degree
Tag: 虚树+换根+树上最大流
#include<iostream> #include<cstdio> #include<cstring> using namespace std; const int M=500010; struct link { int x,y,z; } a[M*2]; int b[M],c[M],dp[M],f[M],r,s,ans; bool v[M]; void clear() { s=ans=0,r=1; memset(v,0,sizeof(v)); memset(a,0,sizeof(a)); memset(dp,0,sizeof(dp)); memset(c,0,sizeof(c)); memset(b,0,sizeof(b)); memset(f,0,sizeof(f)); } void add(int x,int y,int z) { a[++s].x=y,a[s].y=b[x],b[x]=s,a[s].z=z; } void dfs(int x) { int y,i; v[x]=1; for(i=b[x];i;i=a[i].y) { y=a[i].x; if(!v[y]) { dfs(y); if(c[y]!=1) dp[x]+=min(a[i].z,dp[y]); else dp[x]+=a[i].z; } } } void ddfs(int x) { int y,i; v[x]=1; for(i=b[x];i;i=a[i].y) { y=a[i].x; if(!v[y]) { if(c[x]!=1) f[y]=dp[y]+min(f[x]-min(dp[y],a[i].z),a[i].z); else f[y]=dp[y]+a[i].z; ans=max(ans,f[y]); ddfs(y); } } } int main() { int t,x,y,z,n,i; scanf("%d",&t); while(t--) { clear(); scanf("%d",&n); for(i=1;i<n;i++) { scanf("%d%d%d",&x,&y,&z); c[x]++,c[y]++; add(x,y,z); add(y,x,z); } dfs(r); memset(v,0,sizeof(v)); f[r]=dp[r]; ddfs(r); printf("%d\n",ans); } return 0; }
5. 洛谷P4103 [HEOI2014] - 大工程
Tag: 虚树+树上最短/最长路径
AC代码
#include<bits/stdc++.h> using namespace std; const int N=1000010,inf=0x7f7f7f7f7f; struct link { int a,b; } l[2*N]; int a[N],p[N],q[N],f[22][N],t[N*4],n,m,tot,w; long long c[N],g[N],mi[N],ma[N],d[N],ans1,ans2,ans3; bool b[N],v[N]; stack<int> s; void add(int x,int y) { l[++tot].a=y,l[tot].b=a[x],a[x]=tot; } void dfs(int x) { int y=a[x],z,i; p[x]=++w,b[x]=1; for(i=0;f[i][x];i++) f[i+1][x]=f[i][f[i][x]]; while(y) { z=l[y].a; if(!b[z]) { d[z]=d[x]+1,f[0][z]=x; dfs(z); } y=l[y].b; } q[x]=++w; } int lca(int x,int y) { if(d[x]<d[y]) swap(x,y); int z=d[x]-d[y],i; for(i=0;z;z>>=1,i++) if(z%2) x=f[i][x]; if(x==y) return x; for(i=21;i>=0;i--) if(f[i][x]!=f[i][y]) x=f[i][x],y=f[i][y]; return f[0][x]; } bool cmp(int x,int y) { int xx,yy; if(x>0) xx=p[x]; else xx=q[-x]; if(y>0) yy=p[y]; else yy=q[-y]; return xx<yy; } int main() { int tot,x,y,z,i,j,k; long long ss; scanf("%d",&n); for(i=1;i<=n;i++) mi[i]=inf; for(i=1;i<n;i++) { scanf("%d%d",&x,&y); add(x,y); add(y,x); } dfs(1); scanf("%d",&m); for(i=1;i<=m;i++) { scanf("%d",&k); ans1=ans3=0,ans2=inf,tot=k; for(j=1;j<=k;j++) { scanf("%d",&t[j]); mi[t[j]]=0,g[t[j]]=1,v[t[j]]=1; } sort(t+1,t+k+1,cmp); t[++tot]=-t[1]; for(j=2;j<=k;j++) { z=lca(t[j],t[j-1]); t[++tot]=-t[j]; if(!v[z]) t[++tot]=z,t[++tot]=-z,v[z]=1; } sort(t+1,t+tot+1,cmp); for(j=1;j<=tot;j++) { if(t[j]>0) { s.push(t[j]); continue; } if(t[j]<0) { x=s.top(); s.pop(); if(!s.empty()) { y=s.top(),ss=(d[x]-d[y]); c[x]+=g[x]*ss,ans1+=g[y]*c[x]+g[x]*c[y]; g[y]+=g[x],c[y]+=c[x]; mi[x]+=ss,ans2=min(ans2,mi[y]+mi[x]),mi[y]=min(mi[y],mi[x]); ma[x]+=ss,ans3=max(ans3,ma[y]+ma[x]),ma[y]=max(ma[y],ma[x]); } g[x]=c[x]=ma[x]=0,mi[x]=inf,v[x]=0; } } printf("%lld %lld %lld\n",ans1,ans2,ans3); } }