题目链接

题面:

题意:
有一棵 n n n 个节点的树,边权均为 1 1 1,从上面选 m m m 个点的方案为 C n m C_n^m Cnm
对于每一种方案,该方案的权重定义为这 m m m 个点到树上某一点的距离和的最小值。我们定义这一点为最优点。
C n m C_n^m Cnm 种方案的权重的和。

题解:
我没枚举每一条边,假设这条边两侧的节点数分别为 s s s n s n-s ns。我们在这条边两侧选的节点数为 i i i m i m-i mi,我们可以知道,最优点一定在选的点数较多的一侧。

那么对于某条边来说较为容易得到公式:

f ( s ) = i = 1 m 1 C s i C n s m i m i n ( i , m i ) f(s)=\sum\limits_{i=1}^{m-1}C_s^i*C_{n-s}^{m-i}*min(i,m-i) f(s)=i=1m1CsiCnsmimin(i,mi)

显然,对于每一条边,计算该式子的时间复杂度是 O ( n 2 ) O(n^2) O(n2) 的。

考虑化简,我们令 p = m 1 2 p=\left\lfloor\frac{m-1}{2}\right\rfloor p=2m1,那么我们得到式子:

f ( s ) = i = 1 p C s i C n s m i i + i = 1 p C s m i C n s i i + [ m <mtext>   </mtext> m o d <mtext>   </mtext> 2 = 0 ] C s m 2 C n s m 2 m 2 f(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i+\sum\limits_{i=1}^pC_s^{m-i}*C_{n-s}^i*i+[m\space mod\space 2=0]*C_s^{\frac{m}{2}}*C_{n-s}^{\frac{m}{2}}*\frac{m}{2} f(s)=i=1pCsiCnsmii+i=1pCsmiCnsii+[m mod 2=0]Cs2mCns2m2m

我们令 g ( s ) = i = 1 p C s i C n s m i i g(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i g(s)=i=1pCsiCnsmii h ( s ) = i = 1 p C s m i C n s i i h(s)=\sum\limits_{i=1}^pC_s^{m-i}*C_{n-s}^i*i h(s)=i=1pCsmiCnsii k ( s ) = C s m 2 C n s m 2 m 2 k(s)=C_s^{\frac{m}{2}}*C_{n-s}^{\frac{m}{2}}*\frac{m}{2} k(s)=Cs2mCns2m2m

容易发现 h ( s ) = g ( n s ) h(s)=g(n-s) h(s)=g(ns),现在我们考虑怎么快速求出 g ( s ) g(s) g(s)

g ( s ) = i = 1 p C s i C n s m i i g(s)=\sum\limits_{i=1}^pC_s^i*C_{n-s}^{m-i}*i g(s)=i=1pCsiCnsmii

C s i = <mstyle displaystyle="true" scriptlevel="0"> s ! i ! ( s i ) ! </mstyle> = <mstyle displaystyle="true" scriptlevel="0"> ( s 1 ) ! s ( i 1 ) ! i ( s i ) ! </mstyle> = C s 1 i 1 <mstyle displaystyle="true" scriptlevel="0"> s i </mstyle> C_s^i=\dfrac{s!}{i!*(s-i)!}=\dfrac{(s-1)!*s}{(i-1)!*i*(s-i)!}=C_{s-1}^{i-1}*\dfrac{s}{i} Csi=i!(si)!s!=(i1)!i(si)!(s1)!s=Cs1i1is

g ( s ) = s i = 1 p C s 1 i 1 C n s m i = s t ( s ) g(s)=s*\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i}=s*t(s) g(s)=si=1pCs1i1Cnsmi=st(s) ,其中 t ( s ) = i = 1 p C s 1 i 1 C n s m i t(s)=\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i} t(s)=i=1pCs1i1Cnsmi

考虑给定 t ( s ) t(s) t(s) 一个定义:
n 1 n-1 n1 个位置,放置 m 1 m-1 m1 个球,每个球只能放在一个位置上,每个位置至多放置一个球。其中要求前 s 1 s-1 s1 个位置至多放置 p 1 p-1 p1 个球。
得到:
t ( s ) = i = 1 p C s 1 i 1 C n s m i t(s)=\sum\limits_{i=1}^pC_{s-1}^{i-1}*C_{n-s}^{m-i} t(s)=i=1pCs1i1Cnsmi

明显需要满足 p > = 1 p>=1 p>=1。且 s = 1 s=1 s=1时, t ( s ) = C n 1 m 1 t(s)=C_{n-1}^{m-1} t(s)=Cn1m1

考虑:怎么由 t ( s 1 ) t(s-1) t(s1) 得到 t ( s ) t(s) t(s)

要求改变的地方为,从前 s 2 s-2 s2 个位置至多放置 p 1 p-1 p1 个球,转化为前 s 1 s-1 s1 个位置至多放置 p 1 p-1 p1 个球。

考虑哪些不合法。
那些在 t ( s 1 ) t(s-1) t(s1) 种合法且在 t ( s ) t(s) t(s) 种不合法的一定是,前 s 2 s-2 s2 个位置已经放置了 p 1 p-1 p1 个球,但是第 s 1 s-1 s1 的位置还有一个球。即 C s 2 p 1 C n s m 1 p C_{s-2}^{p-1}*C_{n-s}^{m-1-p} Cs2p1Cnsm1p

这样,我们可以快速求出 t ( s ) t(s) t(s),从而快速得到 g ( s ) g(s) g(s),从而得到 h ( s ) h(s) h(s),最终得到 f ( s ) f(s) f(s)

注意 m = 1 m=1 m=1 m = 2 m=2 m=2 这两种情况下, p = 0 p=0 p=0

代码:

#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<unordered_set>
#include<set>
#include<ctime>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
//#define lc (cnt<<1)
//#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define fhead(x) for(int i=head[(x)];i;i=nt[i])
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;

const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const double alpha=0.75;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=1000100;
const int maxm=100100;
const int maxp=100100;
const int up=1100;

ll fac[maxn],inv[maxn];
ll t[maxn],g[maxn],h[maxn],k[maxn],ans;
int f[maxn],si[maxn],n,m;

ll mypow(ll a,ll b)
{
    ll ans=1;
    while(b)
    {
        if(b&1) ans=ans*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return ans;
}

void init(void)
{
    fac[0]=1;
    for(int i=1;i<maxn;i++)
        fac[i]=fac[i-1]*i%mod;
    inv[maxn-1]=mypow(fac[maxn-1],mod-2);
    for(int i=maxn-2;i>=0;i--)
        inv[i]=inv[i+1]*(i+1)%mod;
}

ll C(ll n,ll m)
{
    if(n<0||m<0||m>n) return 0;
    return fac[n]*inv[m]%mod*inv[n-m]%mod;
}

int main(void)
{
    init();
    int tt;
    scanf("%d",&tt);
    while(tt--)
    {
        scanf("%d%d",&n,&m);
        for(int i=2;i<=n;i++)
            scanf("%d",&f[i]),si[i]=1;
        ans=0;
        int p=(m-1)/2;
        t[1]=p?C(n-1,m-1):0;
        g[1]=t[1]*1;
        for(int s=2;s<=n;s++)
        {
            t[s]=((t[s-1]-C(s-2,p-1)*C(n-s,m-1-p)%mod)%mod+mod)%mod;
            g[s]=t[s]*s%mod;
        }
        for(int s=1;s<=n;s++)
        {
            h[s]=g[n-s];
            k[s]=C(s,m/2)*C(n-s,m/2)%mod*(m/2)%mod;
        }
        int now=0;
        for(int i=n;i>=2;i--)
        {
            si[f[i]]+=si[i];
            now=min(si[i],n-si[i]);
            ans=(ans+g[now]+h[now]+(m%2==0?k[now]:0))%mod;
        }
        printf("%lld\n",ans);

    }
    return 0;
}