D Directed
题意:给定一个 个点以 为根的树,现在要从 出发到 号节点。现在随机选择 条树边变成单向边,方向由儿子指向父亲。同时人在某一个节点以等概率选择出边,问期望多少步走到 号节点。
解法:对于一个有根树,从儿子节点 走到父节点 ,其期望步数为 ,其中 表示 的子树大小。
由此,如果没有改为单向边的操作,考虑从 到 的期望步数,即是从 一步一步走到其祖先直到 的期望步数之和。记从 到 的路径为特殊路径 :,基础答案为 。
考虑有单向边的情况。如果一条边 (不妨令 是 的祖先)作为单向边,那么其祖先的子树大小会受到影响(等效于从祖先视角看, 这个子树被删除了), 的子树都不会受到影响。进一步的,由于统计的期望步数仅由 上的子树和决定,因而影响区间为 在 上的最近祖先至 的这一条 上的链。
考虑两种情况:
- 为 上的链。如果当前边不选,则对步数的贡献会减少 ——向上的一步还是要计入。考虑从 开始到 的每一步的期望贡献:对于 的一级祖先 , 的期望步数要减少 ,对于任何情况均成立,仅需要满足 被选定,方案数为 ,概率为 ;对于 的二级祖先 , 的期望步数要减少 而非 ,需要 不是单向边,否则从 上升到 时,贡献为 而不是 ,因而可选择范围为 —— 选定,且 必不选;三级祖先需要满足 和 均不是单向边,因而概率为 ;因而,若 到 的路径长度为 ,则总的概率为 ,对答案的贡献需要减去 。
- 不是 上的链。记 到链 上的最近祖先为 ,则 的贡献能被直接记录在 上而非 到 的其他祖先上的概率为 的路径上所有边都不允许是单向边。然后在 上从 到 的分析和上面相同。因而其概率为 ,对答案的贡献为 。
总时间复杂度仅为遍历复杂度,。
(队友代码)
#include<bits/stdc++.h>
#define IL inline
#define LL long long
using namespace std;
const int N=1e6+3,p=998244353;
struct hh{
int to,nxt;
}e[N<<1];
int n,k,s,num,fir[N],fac[N],ifac[N],bo[N],siz[N],dep[N],P[N],fa[N],bel[N],pp,ans;
IL int in(){
char c;int f=1;
while((c=getchar())<'0'||c>'9')
if(c=='-') f=-1;
int x=c-'0';
while((c=getchar())>='0'&&c<='9')
x=x*10+c-'0';
return x*f;
}
IL int mod(int x){return x>=p?x-p:x;}
IL void add(int x,int y){e[++num]=(hh){y,fir[x]},fir[x]=num;}
IL int ksm(int a,int b){
int c=1;
while(b){
if(b&1) c=1ll*c*a%p;
a=1ll*a*a%p,b>>=1;
}
return c;
}
void init(){
fac[0]=1;for(int i=1;i<=n;++i) fac[i]=1ll*fac[i-1]*i%p;
ifac[n]=ksm(fac[n],p-2);
for(int i=n;i;--i) ifac[i-1]=1ll*ifac[i]*i%p;
}
IL int C(int n,int m){if(n<m) return 0;return 1ll*fac[n]*ifac[m]%p*ifac[n-m]%p;}
void dfs1(int u,int f){
fa[u]=f,siz[u]=1,dep[u]=dep[f]+1;
for(int i=fir[u],v;v=e[i].to;i=e[i].nxt)
if(v^f) dfs1(v,u),siz[u]+=siz[v];
}
void dfs2(int u,int bl){
bel[u]=bl;
for(int i=fir[u],v;v=e[i].to;i=e[i].nxt)
if(v^fa[u]){
if(bo[v]) dfs2(v,v);
else dfs2(v,bl);
}
}
int main()
{
int x,y;
n=in(),k=in(),s=in(),init();
for(int i=1;i<n;++i)
x=in(),y=in(),add(x,y),add(y,x);
dfs1(1,0);
bo[s]=1;while(s!=1) bo[s=fa[s]]=1;
dfs2(1,1);
pp=ksm(C(n-1,k),p-2);
for(int i=1;i<n;++i) P[i]=1ll*C(n-1-i,k-1)*pp%p,P[i]=mod(P[i-1]+P[i]);
for(int i=2;i<=n;++i) if(bo[i]) ans=mod(ans+mod(2*siz[i]-1));
for(int i=2;i<=n;++i)
if(dep[i]>2){
if(bo[i]) ans=mod(ans-1ll*P[dep[i]-2]*mod(2*siz[i])%p+p);
else ans=mod(ans+p-1ll*mod(P[dep[i]-2]-P[dep[i]-dep[bel[i]]-1]+p)*mod(2*siz[i])%p);
}
printf("%d\n",ans);
return 0;
}