题目大意

给定N个点构成的有根树,顶点编号从1-N,根节点为1号点。你可以选最多K个点(根必须选),使得所有点的最大“祖先距离”尽可能的小。
点x的“祖先距离”是在点x到根节点上的路径上,点x与第一个关键点的距离。若没有关键点,则距离为正无穷大。(例如1-2-3树上,关键点为2,则三个点的“祖先距离”分别为{+∞,0,1}。)

解题思路

借鉴了https://blog.csdn.net/tianyizhicheng/article/details/107512243的做法。
当答案一定的时候会有多种的k(比如说一条长度为4的链,放2个关键点和3个关键点的答案是一样的,都是1)。
利用这个特性在上面做线段树,每次暴力找出答案为m的时候,要放多少个点。
这里用到了答案不同时k的不连续性,于是可以对k进行分块处理,分块后跑线段树,节省时间。

AC代码

#include<bits/stdc++.h>
using namespace std;
int a[200010],f[200010],d[200010],p[200010],n,t;
long long b[200010];
vector<int> v[200010];
void dfs(int x)
{
	t++,p[t]=x;
	for(int i=0;i<v[x].size();i++)
	{
		d[v[x][i]]=d[x]+1;
		dfs(v[x][i]);
	}
}
void build(int l,int r,int u,int v) //u为当前关键点数量的左界,v为右界。 {
	if(l>r || u>v) return;
	if(u==v) //l长度需要的关键点的数量是u {
		for(int i=l;i<=r;i++) b[i]=u;
		return;
	}
	int m=(l+r)/2,i; b[m]=0;
	for(i=1;i<=n;i++) a[i]=d[i];
	for(i=n;i>=1;i--)
	{
		int x=p[i];
		if(x==1 || a[x]==d[x]+m) a[x]=-1,b[m]++;
		a[f[x]]=max(a[f[x]],a[x]);
	}
	build(l,m-1,b[m],v); build(m+1,r,u,b[m]);
}
int main()
{
	int m,i; long long ans; 
        while(scanf("%d",&n)!=EOF)
	{
		for(i=1;i<=n;i++) v[i].clear();
		t=m=ans=0;
		for(i=2;i<=n;i++)
		{
			scanf("%d",&f[i]);
			v[f[i]].push_back(i);
		}
		dfs(1);
		for(i=1;i<=n;i++) m=max(m,d[i]);
		build(0,m,1,n);
		for(i=1;i<=m;i++)
			ans+=1ll*i*(b[i-1]-b[i]);
		printf("%lld\n",ans);
	}
}

附:官方标程

摘自官方题解:
按顺序枚举 1~k,争取每次验证二分时,把复杂度和 N剥离开来,搞成和关键点数量有关。
如果每次验证的复杂度是 关键点数量*log,那么容易证明总复杂度是𝑁(log 𝑁)2的(还有一个log是调和级数)。
(什么是调和级数?https://zhuanlan.zhihu.com/p/95763963)

我们要使复杂度和最终放置的关键点个数有关,用线段树去维护整棵树的 DFS 序。
每次的操作是:找到最深的还未被染色的节点 p;将 p 向上 x 步的节点 q 置为关键点,把 q 的子树区间染色。
 这样每次验证的复杂度是 关键点×log 𝑁。
(代码有点丑)
#include<bits/stdc++.h>
#define N 200005
using namespace std;
int tot,last[N],to[N],Next[N],q[N],ans[N],st[N],en[N];
int dis[N],F[N],n,m,A[N],fa[N],si[N];
struct node {
	int s,v,d;
}f[N*4];
inline void add(int x,int y) {
	Next[++tot]=last[x]; last[x]=tot; to[tot]=y;
}
inline void dfs(int x,int y) {
	if (!y) y=x;
	F[x]=y;
	st[x]=++tot;
	A[tot]=x;
	int gt=0,gtw=0;
	for (int i=last[x];i;i=Next[i]) {
		if (si[to[i]]>gt) gt=si[gtw=to[i]];
	}
	if (gtw) dfs(gtw,y);
	for (int i=last[x];i;i=Next[i]) 
		if (to[i]!=gtw) dfs(to[i],0);
	en[x]=tot;
}
inline int Max(int x,int y) {
	if (!x||!y) return x+y;
	if (dis[x]>dis[y]) return x;
	return y;
}
inline void up(int x) {
	f[x].v=Max(f[x*2].v,f[x*2+1].v);
}
inline void build(int o,int l,int r) {
	if (l==r) {
		f[o].v=f[o].s=A[l];
		return ;
	}
	int mid=(l+r)>>1;
	build(o*2,l,mid),build(o*2+1,mid+1,r);
	up(o);
	f[o].s=Max(f[o*2].s,f[o*2+1].s);
}
inline void change(int o,int l,int r,int ll,int rr,int p) {
	if (p&&f[o].d) return ;
	if (f[o].d) {
		if (l!=r) {		
			f[o*2].d=f[o*2+1].d=1;
			f[o*2].v=f[o*2].s;
			f[o*2+1].v=f[o*2+1].s;
		}
		f[o].d=0;
	}
	if (l==ll&&r==rr) {
		if (p) f[o].v=f[o].s,f[o].d=1;
		else f[o].v=0;
		return ;
	}
	
	int mid=(l+r)>>1;
	if (rr<=mid) change(o*2,l,mid,ll,rr,p);
	else if (ll>mid) change(o*2+1,mid+1,r,ll,rr,p);
	else change(o*2,l,mid,ll,mid,p),change(o*2+1,mid+1,r,mid+1,rr,p);
	up(o);
}
inline int find(int x,int k) {
	if (dis[x]<=k) return 1;
	while (1) {
		if (k<=dis[x]-dis[F[x]]) return A[st[x]-k];
		k-=(dis[x]-dis[F[x]]+1);
		x=fa[F[x]];
	}
}
inline int del(int x,int k) {
	int p=find(x,k);
	change(1,1,n,st[p],en[p],0);
	return p;
}
inline void ins(int x) {
	change(1,1,n,st[x],en[x],1);
}
inline int get(int k) {
	int r=0;
	while (1) {
		if (f[1].v==0) break;
		q[++r]=del(f[1].v,k);
	}
	for (int i=r;i;i--) ins(q[i]);
	return r;
}
int main() {
	while (scanf("%d",&n)!=EOF) {
		for (int i=1;i<=n;i++) si[i]=last[i]=0,ans[i]=n+1;
		tot=0;
		for (int i=2;i<=n;i++) {
			int x;
			scanf("%d",&x);
			assert(x<i);
			fa[i]=x;
			add(x,i);
			
		}
		for (int i=n;i;i--) si[i]++,si[fa[i]]+=si[i];
		for (int i=1;i<=n;i++) dis[i]=dis[fa[i]]+1;
		tot=0;
		dfs(1,0);
		build(1,1,n);
		for (int i=n;i>=0;i--) ans[get(i)]=i;
		for (int i=2;i<=n;i++) ans[i]=min(ans[i],ans[i-1]);
		long long Ans=0;
		for (int i=1;i<n;i++) Ans+=ans[i];
		printf("%lld\n",Ans);
	}
}