题目翻译

有一棵 \(n\) 个节点的树 ( $n \le 2 \times 10^5 $ ),现在要求选出 \(k\) 个节点,使得这 \(k\) 个节点到根节点的最短路径中,每个节点经过的剩余 \(n-k\) 个节点的数量之和最大。

思路

注:

  • 这里所说的 \(u\) 的子树不包含 \(u\) 本身。
  • \(dep\) 指深度,\(siz\) 指子树大小(包含子树的根)。

先 dfs 一遍,把树的深度和子树大小处理出来,然后贪心。
如果直接按照深度或者子树大小排序都是错的,这是因为新选一个节点可能会导致答案减小。所以我们需要一种更优的排序策略。
因为选 \(u\) 答案增加 \(dep_u-1\),而选 \(u\) 子树中的节点增加的值一定大于 \(dep_u-1\)。而且选 \(u\) 子树中的节点,减少的值一定小于等于选 \(u\) 节点。所以如果要选 \(u\) 节点,那么 \(u\) 结点的子树一定都选完了。
所以我们排序的时候就认为当选 \(u\) 时,它的子树都已经被选完。
所以选节点 \(u\) 会增加的值就是:

  • \((dep_u-1) \times siz_u\)

选节点 \(u\) 会减少的值就是:

  • \((siz_u-1)\times dep_u\)

选节点 \(u\) 对整个答案的贡献就是:

  • \((dep_u-1)\times siz_u-(siz_u-1) \times dep_u\)

按照这个贡献排序,选前 \(k\) 名即可。
答案再 dfs 一遍就可以得到。

Code

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<bitset>
#include<cmath>
#include<queue>
#include<map>
#include<set>
#define LL long long

using namespace std;

LL read()
{
	LL ans=0,f=1;
	char c=getchar();
	while(c>'9'||c<'0'){if(c=='-')f=-1;c=getchar();}
	while(c>='0'&&c<='9'){ans=ans*10+c-'0';c=getchar();}
	return ans*f;
}

const LL N=1e6+5;
LL n,k,is[N],sum[N],ans;

struct Edge
{
	LL v,ne;
}e[N];
LL head[N],tot;

struct Node
{
	LL dep,x,siz;
    // 分别代表 深度 ,编号 ,子树大小
	bool operator < (const Node &x)const
	{
		return ((x.dep-1)*x.siz-x.dep*(x.siz-1))<((dep-1)*siz-dep*(siz-1));
        // 这是上面推出的排序策略
	}
}node[N];

inline void add(LL u,LL v);
void dfs1(LL u,LL fa);
void dfs2(LL u,LL fa);

int main()
{
	n=read();k=read();
	for(LL i=1;i<=n-1;++i)
	{
		LL u=read(),v=read();
		add(u,v);
		add(v,u);
        // 注意无向边
	}
	dfs1(1,0);
	sort(node+1,node+1+n);
	for(LL i=1;i<=k;++i)
    // 选出前 k 个节点
		is[node[i].x]=1;
	dfs2(1,0);
	printf("%lld",ans);
	return 0;
}

inline void add(LL u,LL v)
{
	e[++tot]=(Edge){v,head[u]};
	head[u]=tot;
}

void dfs1(LL u,LL fa)
// 第一遍 dfs 预处理出每个节点的深度与子树大小
{
	node[u].dep=node[fa].dep+1;
	node[u].siz=1;
	node[u].x=u;
	for(LL i=head[u];i;i=e[i].ne)
	{
		LL v=e[i].v;
		if(v==fa)
			continue;
		dfs1(v,u);
		node[u].siz+=node[v].siz;
	}
}

void dfs2(LL u,LL fa)
// 第二遍 dfs 利用前缀和求答案
{
	if(!is[u])
		sum[u]++;
	else
		ans+=sum[u];
	for(LL i=head[u];i;i=e[i].ne)
	{
		LL v=e[i].v;
		if(v==fa)
			continue;
		sum[v]=sum[u];
		dfs2(v,u);
	}
}