题目链接

题面:

题意:
给定一棵有根树,根节点是1号节点,在上面选择k个关键点。每个点的权值定义为该点与其最近关键祖先的距离(其祖先中是关键点且距离该点最近的那一个),若其祖先中没有关键点则权值为正无穷,若其本身就是关键点则权值为0。整棵树的权值定义为所有点的权值的最大值。问通过某种方法安排这k个关键点,使得这棵树的权值最小。
输出 k = 1 n v a l ( t r e e k ) \sum_{k=1}^nval(tree_k) k=1nval(treek)

题解:
假设我们现在已经知道了这棵树的权值为 x,那么最少需要多少个关键点。

每次选择当前深度最深的点,将它的第 x 个祖先设为关键点,并且删除这个关键点的子树,直到整棵树被删完。这样做需要的关键点是最少的。因为每次选择的都是必须要选择的关键点。

显然,当这棵树的权值为 x 时,所需要的关键点的数量最多为 n x + 1 + 1 \frac{n}{x+1}+1 x+1n+1,因为选择一个关键点至少要删掉 x + 1 x+1 x+1 个点。

现在我们枚举这棵树的权值 x { 0 , 1 , 2 , 3 , . . . n } x\in\{0,1,2,3,...n\} x{0,1,2,3,...n},每次计算出权值为 x x x 时所需要的最少的关键点的数量 s u m [ x ] sum[x] sum[x],对于每一个关键点数量 c n t [ i ] cnt[i] cnt[i] 取一个 m i n ( x s u m [ x ] = i ) min(x|sum[x]=i) min(xsum[x]=i),然后再正向更新一遍 c n t [ i ] cnt[i] cnt[i],此时的 c n t cnt cnt 数组即为所求。

现在考虑怎么对于一个权值 x x x 求解所需要的最少的关键点的数量。
我们dfs序建立线段树,维护每个点的深度,每次找到深度最深的点,将他的第 x 个祖先节点设为关键点(倍增),并且删除这个关键点的子树,将这棵子树的权值全部加上负无穷。直到整棵线段树上所维护的每个点的深度都 < = 0 <=0 <=0 那么说明符合要求。此时所用的关键点数,就是权值 x x x 所需要的最少关键点。
退出时撤销以上操作,使得线段树维护的权值为每个点的深度。

时间复杂度分析:
O ( n 1 l o g n + n 2 l o g n + . . . + n n l o g n ) = O ( n l o g 2 n ) O(\frac{n}{1}logn+\frac{n}{2}logn+...+\frac{n}{n}logn)=O(nlog^2n) O(1nlogn+2nlogn+...+nnlogn)=O(nlog2n)

因为树的权值 x = 0 x=0 x=0 时,一定需要n个关键点,就没必要再去求一次了,直接记录即可(其实是如果 x=0 也求一次,那么就会TLE。)

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<set>
#include<list>
#include<ctime>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
#define lc (cnt<<1)
#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const int mod=1e9+7;
const double eps=1e-1;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=200100;
const int maxp=400100;
const int maxm=600100;
const int up=200000;

int head[maxn],ver[maxn<<1],nt[maxn<<1],tot=1;
int d[maxn],f[maxn][20],dfn[maxn],id[maxn],si[maxn],cm=0,tt;
int sum[maxn],cnt[maxn];
int st[maxn],top=0;
struct node
{
    int l,r,id;
    ll maxx,laz;
}t[maxn<<2];

void init(int n)
{
    for(int i=1;i<=n;i++)
        head[i]=0,cnt[i]=inf;
    tot=1,cm=0;
    tt=log2(n)+1;
    d[1]=1;
}

void add(int x,int y)
{
    ver[++tot]=y,nt[tot]=head[x],head[x]=tot;
}

void dfs(int x,int fa)
{
    dfn[x]=++cm,id[cm]=x;
    si[x]=1;
    for(int i=head[x];i;i=nt[i])
    {
        int y=ver[i];
        if(y==fa) continue;
        d[y]=d[x]+1;
        f[y][0]=x;
        for(int j=1;j<=tt;j++)
            f[y][j]=f[f[y][j-1]][j-1];
        dfs(y,x);
        si[x]+=si[y];
    }
}

void pushup(int cnt)
{
    t[cnt].maxx=max(t[lc].maxx,t[rc].maxx);
}

void pushdown(int cnt)
{
    if(t[cnt].laz)
    {
        t[lc].laz+=t[cnt].laz;
        t[rc].laz+=t[cnt].laz;
        t[lc].maxx+=t[cnt].laz;
        t[rc].maxx+=t[cnt].laz;
        t[cnt].laz=0;
    }
}

void build(int l,int r,int cnt)
{
    t[cnt].l=l,t[cnt].r=r;
    t[cnt].laz=t[cnt].maxx=0;
    if(l==r)
    {
        t[cnt].maxx=d[id[l]];
        return ;
    }
    build(l,tmid,lc);
    build(tmid+1,r,rc);
    pushup(cnt);
}

void change(int l,int r,int cnt,ll val)
{
    if(l<=t[cnt].l&&t[cnt].r<=r)
    {
        t[cnt].laz+=val;
        t[cnt].maxx+=val;
        return ;
    }
    pushdown(cnt);
    if(t[lc].r>=l) change(l,r,lc,val);
    if(t[rc].l<=r) change(l,r,rc,val);
    pushup(cnt);
}

int ask(int cnt)
{
    if(t[cnt].l==t[cnt].r)
        return id[t[cnt].l];
    pushdown(cnt);
    if(t[lc].maxx==t[cnt].maxx) return ask(lc);
    else return ask(rc);
}

int getsum(int x)
{
    top=0;
    int maxdx,keyx,dkeyx;
    while(true)
    {
        if(t[1].maxx<=0) break;
        maxdx=ask(1);
        if(d[maxdx]<=x+1)  keyx=1;
        else
        {
            dkeyx=d[maxdx]-x;
            for(int i=tt;i>=0;i--)
                if(d[f[maxdx][i]]>=dkeyx) maxdx=f[maxdx][i];
            keyx=maxdx;
        }
        st[++top]=keyx;
        change(dfn[keyx],dfn[keyx]+si[keyx]-1,1,-inf);
    }
    for(int i=1;i<=top;i++)
        change(dfn[st[i]],dfn[st[i]]+si[st[i]]-1,1,inf);
    return top;
}

int main(void)
{
    int n,x;
    while(scanf("%d",&n)!=EOF)
    {
        init(n);
        for(int i=2;i<=n;i++)
        {
            scanf("%d",&x);
            add(x,i);
            add(i,x);
        }
        dfs(1,0);
        build(1,n,1);
        sum[0]=n;
        for(int i=1;i<=n;i++) sum[i]=getsum(i);
        for(int i=0;i<=n;i++)
            cnt[sum[i]]=min(cnt[sum[i]],i);
        int ans=cnt[1];
        for(int i=2;i<=n;i++)
            cnt[i]=min(cnt[i-1],cnt[i]),ans+=cnt[i];
        printf("%d\n",ans);
    }
    return 0;
}