D Directed

题意:给定一个 nn 个点以 11 为根的树,现在要从 ss 出发到 11 号节点。现在随机选择 kk 条树边变成单向边,方向由儿子指向父亲。同时人在某一个节点以等概率选择出边,问期望多少步走到 11 号节点。

解法:对于一个有根树,从儿子节点 vv 走到父节点 uu,其期望步数为 Ev=2sizv1E_v=2siz_v-1,其中 sizvsiz_v 表示 vv 的子树大小。

由此,如果没有改为单向边的操作,考虑从 ss11 的期望步数,即是从 ss 一步一步走到其祖先直到 11 的期望步数之和。记从 ss11 的路径为特殊路径 llsv1v21s \to v_1 \to v_2 \to \cdots \to 1,基础答案为 vl,v12sizv1\displaystyle \sum_{v \in l, v \neq 1}2siz_v-1

考虑有单向边的情况。如果一条边 (u,v)(u,v) (不妨令 uuvv 的祖先)作为单向边,那么其祖先的子树大小会受到影响(等效于从祖先视角看,uu 这个子树被删除了),vv 的子树都不会受到影响。进一步的,由于统计的期望步数仅由 ll 上的子树和决定,因而影响区间为 uull 上的最近祖先至 11 的这一条 ll 上的链。

考虑两种情况:

  1. (u,v)(u,v)ll 上的链。如果当前边不选,则对步数的贡献会减少 2sizu2siz_u——向上的一步还是要计入。考虑从 uu 开始到 11 的每一步的期望贡献:对于 uu 的一级祖先 a1a_1ua1u \to a_1 的期望步数要减少 2sizu2siz_u,对于任何情况均成立,仅需要满足 vuv \to u 被选定,方案数为 (n11k1)\displaystyle {n-1-1 \choose k-1},概率为 Pr[1]=(n11k1)/(n1k)\Pr[1]=\displaystyle {n-1-1 \choose k-1}\left/{n-1 \choose k}\right.;对于 uu 的二级祖先 a2a_2a1a2a_1 \to a_2 的期望步数要减少 2sizu2siz_u 而非 2siza12siz_{a_1},需要 ua1u \to a_1 不是单向边,否则从 a2a_2 上升到 a3a_3 时,贡献为 2siza12siz_{a_1} 而不是 2sizu2siz_u,因而可选择范围为 Pr[2]=(n12k1)/(n1k)\Pr[2]=\displaystyle {n-1-2 \choose k-1}\left/{n-1 \choose k}\right.——vuv\to u 选定,且 ua1u \to a_1 必不选;三级祖先需要满足 ua1u \to a_1a1a2a_1 \to a_2 均不是单向边,因而概率为 Pr[3]=(n13k1)/(n1k)\displaystyle \Pr[3]={n-1-3 \choose k-1}\left/{n-1 \choose k}\right.;因而,若 uu11 的路径长度为 depudep_u,则总的概率为 i=1depu1Pr[i]\displaystyle \sum_{i=1}^{dep_u-1} \Pr[i],对答案的贡献需要减去 2sizui=1depu1Pr[i]\displaystyle 2siz_u\sum_{i=1}^{dep_u-1} \Pr[i]
  2. (u,v)(u,v) 不是 ll 上的链。记 uu 到链 ll 上的最近祖先为 aa,则sizusiz_u 的贡献能被直接记录在 ll 上而非 uuaa 的其他祖先上的概率为 uau \to a 的路径上所有边都不允许是单向边。然后在 ll 上从 aa11 的分析和上面相同。因而其概率为 i=depadepuPr[i]\displaystyle \sum_{i=dep_a}^{dep_u} \Pr[i],对答案的贡献为 2sizui=depadepuPr[i]\displaystyle 2siz_u\sum_{i=dep_a}^{dep_u} \Pr[i]

总时间复杂度仅为遍历复杂度,O(n)\mathcal O(n)

(队友代码)

#include<bits/stdc++.h>
#define IL inline
#define LL long long
using namespace std;
const int N=1e6+3,p=998244353;
struct hh{
	int to,nxt;
}e[N<<1];
int n,k,s,num,fir[N],fac[N],ifac[N],bo[N],siz[N],dep[N],P[N],fa[N],bel[N],pp,ans;
IL int in(){
	char c;int f=1;
	while((c=getchar())<'0'||c>'9')
	  if(c=='-') f=-1;
	int x=c-'0';
	while((c=getchar())>='0'&&c<='9')
	  x=x*10+c-'0';
	return x*f;
}
IL int mod(int x){return x>=p?x-p:x;}
IL void add(int x,int y){e[++num]=(hh){y,fir[x]},fir[x]=num;}
IL int ksm(int a,int b){
	int c=1;
	while(b){
		if(b&1) c=1ll*c*a%p;
		a=1ll*a*a%p,b>>=1;
	}
	return c;
}
void init(){
	fac[0]=1;for(int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%p;
	ifac[n]=ksm(fac[n],p-2);
	for(int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%p;
}
IL int C(int n,int m){if(n<m) return 0;return 1ll*fac[n]*ifac[m]%p*ifac[n-m]%p;}
void dfs1(int u,int f){
	fa[u]=f,siz[u]=1,dep[u]=dep[f]+1;
	for(int i=fir[u],v;v=e[i].to;i=e[i].nxt)
	  if(v^f) dfs1(v,u),siz[u]+=siz[v];
}
void dfs2(int u,int bl){
	bel[u]=bl;
	for(int i=fir[u],v;v=e[i].to;i=e[i].nxt)
	  if(v^fa[u]){
	  	if(bo[v]) dfs2(v,v);
	  	else dfs2(v,bl);
	  }
}
int main()
{
	int x,y;
	n=in(),k=in(),s=in(),init();
	for(int i=1;i<n;++i)
	  x=in(),y=in(),add(x,y),add(y,x);
	dfs1(1,0);
	bo[s]=1;while(s!=1) bo[s=fa[s]]=1;
	dfs2(1,1);
	pp=ksm(C(n-1,k),p-2);
	for(int i=1;i<n;++i) P[i]=1ll*C(n-1-i,k-1)*pp%p,P[i]=mod(P[i-1]+P[i]);
	for(int i=2;i<=n;++i) if(bo[i]) ans=mod(ans+mod(2*siz[i]-1));
	for(int i=2;i<=n;++i)
	  if(dep[i]>2){
		  if(bo[i]) ans=mod(ans-1ll*P[dep[i]-2]*mod(2*siz[i])%p+p);
		  else ans=mod(ans+p-1ll*mod(P[dep[i]-2]-P[dep[i]-dep[bel[i]]-1]+p)*mod(2*siz[i])%p);
	  }
	printf("%d\n",ans);
	return 0;
}