题面:
题意:
有一棵 n 个节点的树,边权均为 1,从上面选 m 个点的方案为 Cnm。
对于每一种方案,该方案的权重定义为这 m 个点到树上某一点的距离和的最小值。我们定义这一点为最优点。
求 Cnm 种方案的权重的和。
题解:
我没枚举每一条边,假设这条边两侧的节点数分别为 s, n−s。我们在这条边两侧选的节点数为 i, m−i,我们可以知道,最优点一定在选的点数较多的一侧。
那么对于某条边来说较为容易得到公式:
f(s)=i=1∑m−1Csi∗Cn−sm−i∗min(i,m−i)
显然,对于每一条边,计算该式子的时间复杂度是 O(n2) 的。
考虑化简,我们令 p=⌊2m−1⌋,那么我们得到式子:
f(s)=i=1∑pCsi∗Cn−sm−i∗i+i=1∑pCsm−i∗Cn−si∗i+[m mod 2=0]∗Cs2m∗Cn−s2m∗2m
我们令 g(s)=i=1∑pCsi∗Cn−sm−i∗i , h(s)=i=1∑pCsm−i∗Cn−si∗i, k(s)=Cs2m∗Cn−s2m∗2m
容易发现 h(s)=g(n−s),现在我们考虑怎么快速求出 g(s)。
g(s)=i=1∑pCsi∗Cn−sm−i∗i
Csi=i!∗(s−i)!s!=(i−1)!∗i∗(s−i)!(s−1)!∗s=Cs−1i−1∗is
g(s)=s∗i=1∑pCs−1i−1∗Cn−sm−i=s∗t(s) ,其中 t(s)=i=1∑pCs−1i−1∗Cn−sm−i。
考虑给定 t(s) 一个定义:
n−1 个位置,放置 m−1 个球,每个球只能放在一个位置上,每个位置至多放置一个球。其中要求前 s−1 个位置至多放置 p−1 个球。
得到:
t(s)=i=1∑pCs−1i−1∗Cn−sm−i。
明显需要满足 p>=1。且 s=1时, t(s)=Cn−1m−1
考虑:怎么由 t(s−1) 得到 t(s)。
要求改变的地方为,从前 s−2 个位置至多放置 p−1 个球,转化为前 s−1 个位置至多放置 p−1 个球。
考虑哪些不合法。
那些在 t(s−1) 种合法且在 t(s) 种不合法的一定是,前 s−2 个位置已经放置了 p−1 个球,但是第 s−1 的位置还有一个球。即 Cs−2p−1∗Cn−sm−1−p。
这样,我们可以快速求出 t(s),从而快速得到 g(s),从而得到 h(s),最终得到 f(s)。
注意 m=1 和 m=2 这两种情况下, p=0。
代码:
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<string>
#include<queue>
#include<bitset>
#include<map>
#include<unordered_map>
#include<unordered_set>
#include<set>
#include<ctime>
#define ui unsigned int
#define ll long long
#define llu unsigned ll
#define ld long double
#define pr make_pair
#define pb push_back
//#define lc (cnt<<1)
//#define rc (cnt<<1|1)
#define len(x) (t[(x)].r-t[(x)].l+1)
#define tmid ((l+r)>>1)
#define fhead(x) for(int i=head[(x)];i;i=nt[i])
#define max(x,y) ((x)>(y)?(x):(y))
#define min(x,y) ((x)>(y)?(y):(x))
using namespace std;
const int inf=0x3f3f3f3f;
const ll lnf=0x3f3f3f3f3f3f3f3f;
const double dnf=1e18;
const double alpha=0.75;
const int mod=1e9+7;
const double eps=1e-8;
const double pi=acos(-1.0);
const int hp=13331;
const int maxn=1000100;
const int maxm=100100;
const int maxp=100100;
const int up=1100;
ll fac[maxn],inv[maxn];
ll t[maxn],g[maxn],h[maxn],k[maxn],ans;
int f[maxn],si[maxn],n,m;
ll mypow(ll a,ll b)
{
ll ans=1;
while(b)
{
if(b&1) ans=ans*a%mod;
a=a*a%mod;
b>>=1;
}
return ans;
}
void init(void)
{
fac[0]=1;
for(int i=1;i<maxn;i++)
fac[i]=fac[i-1]*i%mod;
inv[maxn-1]=mypow(fac[maxn-1],mod-2);
for(int i=maxn-2;i>=0;i--)
inv[i]=inv[i+1]*(i+1)%mod;
}
ll C(ll n,ll m)
{
if(n<0||m<0||m>n) return 0;
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
int main(void)
{
init();
int tt;
scanf("%d",&tt);
while(tt--)
{
scanf("%d%d",&n,&m);
for(int i=2;i<=n;i++)
scanf("%d",&f[i]),si[i]=1;
ans=0;
int p=(m-1)/2;
t[1]=p?C(n-1,m-1):0;
g[1]=t[1]*1;
for(int s=2;s<=n;s++)
{
t[s]=((t[s-1]-C(s-2,p-1)*C(n-s,m-1-p)%mod)%mod+mod)%mod;
g[s]=t[s]*s%mod;
}
for(int s=1;s<=n;s++)
{
h[s]=g[n-s];
k[s]=C(s,m/2)*C(n-s,m/2)%mod*(m/2)%mod;
}
int now=0;
for(int i=n;i>=2;i--)
{
si[f[i]]+=si[i];
now=min(si[i],n-si[i]);
ans=(ans+g[now]+h[now]+(m%2==0?k[now]:0))%mod;
}
printf("%lld\n",ans);
}
return 0;
}