题目大意
给定N个点构成的有根树,顶点编号从1-N,根节点为1号点。你可以选最多K个点(根必须选),使得所有点的最大“祖先距离”尽可能的小。
点x的“祖先距离”是在点x到根节点上的路径上,点x与第一个关键点的距离。若没有关键点,则距离为正无穷大。(例如1-2-3树上,关键点为2,则三个点的“祖先距离”分别为{+∞,0,1}。)
解题思路
借鉴了https://blog.csdn.net/tianyizhicheng/article/details/107512243的做法。
当答案一定的时候会有多种的k(比如说一条长度为4的链,放2个关键点和3个关键点的答案是一样的,都是1)。
利用这个特性在上面做线段树,每次暴力找出答案为m的时候,要放多少个点。
这里用到了答案不同时k的不连续性,于是可以对k进行分块处理,分块后跑线段树,节省时间。
AC代码
#include<bits/stdc++.h> using namespace std; int a[200010],f[200010],d[200010],p[200010],n,t; long long b[200010]; vector<int> v[200010]; void dfs(int x) { t++,p[t]=x; for(int i=0;i<v[x].size();i++) { d[v[x][i]]=d[x]+1; dfs(v[x][i]); } } void build(int l,int r,int u,int v) //u为当前关键点数量的左界,v为右界。 { if(l>r || u>v) return; if(u==v) //l长度需要的关键点的数量是u { for(int i=l;i<=r;i++) b[i]=u; return; } int m=(l+r)/2,i; b[m]=0; for(i=1;i<=n;i++) a[i]=d[i]; for(i=n;i>=1;i--) { int x=p[i]; if(x==1 || a[x]==d[x]+m) a[x]=-1,b[m]++; a[f[x]]=max(a[f[x]],a[x]); } build(l,m-1,b[m],v); build(m+1,r,u,b[m]); } int main() { int m,i; long long ans; while(scanf("%d",&n)!=EOF) { for(i=1;i<=n;i++) v[i].clear(); t=m=ans=0; for(i=2;i<=n;i++) { scanf("%d",&f[i]); v[f[i]].push_back(i); } dfs(1); for(i=1;i<=n;i++) m=max(m,d[i]); build(0,m,1,n); for(i=1;i<=m;i++) ans+=1ll*i*(b[i-1]-b[i]); printf("%lld\n",ans); } }
附:官方标程
摘自官方题解:
按顺序枚举 1~k,争取每次验证二分时,把复杂度和 N剥离开来,搞成和关键点数量有关。
如果每次验证的复杂度是 关键点数量*log,那么容易证明总复杂度是𝑁(log 𝑁)2的(还有一个log是调和级数)。
(什么是调和级数?https://zhuanlan.zhihu.com/p/95763963)
我们要使复杂度和最终放置的关键点个数有关,用线段树去维护整棵树的 DFS 序。
每次的操作是:找到最深的还未被染色的节点 p;将 p 向上 x 步的节点 q 置为关键点,把 q 的子树区间染色。
这样每次验证的复杂度是 关键点×log 𝑁。
(代码有点丑)
#include<bits/stdc++.h> #define N 200005 using namespace std; int tot,last[N],to[N],Next[N],q[N],ans[N],st[N],en[N]; int dis[N],F[N],n,m,A[N],fa[N],si[N]; struct node { int s,v,d; }f[N*4]; inline void add(int x,int y) { Next[++tot]=last[x]; last[x]=tot; to[tot]=y; } inline void dfs(int x,int y) { if (!y) y=x; F[x]=y; st[x]=++tot; A[tot]=x; int gt=0,gtw=0; for (int i=last[x];i;i=Next[i]) { if (si[to[i]]>gt) gt=si[gtw=to[i]]; } if (gtw) dfs(gtw,y); for (int i=last[x];i;i=Next[i]) if (to[i]!=gtw) dfs(to[i],0); en[x]=tot; } inline int Max(int x,int y) { if (!x||!y) return x+y; if (dis[x]>dis[y]) return x; return y; } inline void up(int x) { f[x].v=Max(f[x*2].v,f[x*2+1].v); } inline void build(int o,int l,int r) { if (l==r) { f[o].v=f[o].s=A[l]; return ; } int mid=(l+r)>>1; build(o*2,l,mid),build(o*2+1,mid+1,r); up(o); f[o].s=Max(f[o*2].s,f[o*2+1].s); } inline void change(int o,int l,int r,int ll,int rr,int p) { if (p&&f[o].d) return ; if (f[o].d) { if (l!=r) { f[o*2].d=f[o*2+1].d=1; f[o*2].v=f[o*2].s; f[o*2+1].v=f[o*2+1].s; } f[o].d=0; } if (l==ll&&r==rr) { if (p) f[o].v=f[o].s,f[o].d=1; else f[o].v=0; return ; } int mid=(l+r)>>1; if (rr<=mid) change(o*2,l,mid,ll,rr,p); else if (ll>mid) change(o*2+1,mid+1,r,ll,rr,p); else change(o*2,l,mid,ll,mid,p),change(o*2+1,mid+1,r,mid+1,rr,p); up(o); } inline int find(int x,int k) { if (dis[x]<=k) return 1; while (1) { if (k<=dis[x]-dis[F[x]]) return A[st[x]-k]; k-=(dis[x]-dis[F[x]]+1); x=fa[F[x]]; } } inline int del(int x,int k) { int p=find(x,k); change(1,1,n,st[p],en[p],0); return p; } inline void ins(int x) { change(1,1,n,st[x],en[x],1); } inline int get(int k) { int r=0; while (1) { if (f[1].v==0) break; q[++r]=del(f[1].v,k); } for (int i=r;i;i--) ins(q[i]); return r; } int main() { while (scanf("%d",&n)!=EOF) { for (int i=1;i<=n;i++) si[i]=last[i]=0,ans[i]=n+1; tot=0; for (int i=2;i<=n;i++) { int x; scanf("%d",&x); assert(x<i); fa[i]=x; add(x,i); } for (int i=n;i;i--) si[i]++,si[fa[i]]+=si[i]; for (int i=1;i<=n;i++) dis[i]=dis[fa[i]]+1; tot=0; dfs(1,0); build(1,1,n); for (int i=n;i>=0;i--) ans[get(i)]=i; for (int i=2;i<=n;i++) ans[i]=min(ans[i],ans[i-1]); long long Ans=0; for (int i=1;i<n;i++) Ans+=ans[i]; printf("%lld\n",Ans); } }