Random Access Iterator

唉!你能一手双向边秒了我?场上SB了,竟然天真的以为出题人给我的数据时会按照从父节点到儿子节点的顺序给我单向边。。。然后WA到mdbrsl

题意:给定一棵以 1 1 1为根的树(以后存树千万记得存双向边!),然后问从 1 1 1开始能到达最大深度的概率。细节补充:从某个父节点可以等概率的选择到达某个儿子节点,并且选择次数为儿子数,即在多次移动中只要有一次遍历到最大深度即可

思路

  1. 一遍 d f s dfs dfs处理出最大深度,然后给最大深度的节点概率标记为 1 1 1,其他叶子节点标记为概率为 0 0 0
  2. 再跑一边 d f s dfs dfs,遍历过程中不断迭代计算从某个父节点能到达最大深度的概率,最后就可以得到从 1 1 1能到达最大深度的概率
  3. 而具体到每次的计算,我们当前节点的所有儿子节点的概率都算好了
  4. 然后我们先计算一次选择能到达最大深度的概率为 f 0 f_0 f0,显然是所有儿子的概率相加,然后除以儿子数
  5. 然后用 1 f 0 1-f_0 1f0表示一次选择不能到达最大深度的概率,再把这个值自乘 s s s次( s s s为儿子数),这就表示 s s s次选择都不能到达最大深度的概率
  6. 最后,再用 1 1 1来减去这个值,就可以得到 s s s次选择中至少有一次能到达最大深度的概率(听起来有点绕,但下面这个表达式肯定清晰)
  7. f 0 = 1 ( 1 i = 1 s f i s ) s f_0=1-(1-\frac{\sum_{i=1}^{s} f_i}{s})^s f0=1(1si=1sfi)s f 0 f_0 f0就是当前节点能到达最大深度的概率,除法记得用逆元

题面描述

#include "bits/stdc++.h"
#define hhh printf("hhh\n")
#define see(x) (cerr<<(#x)<<'='<<(x)<<endl)
using namespace std;
typedef long long ll;
typedef pair<int,int> pr;
inline int read() {int x=0;char c=getchar();while(c<'0'||c>'9')c=getchar();while(c>='0'&&c<='9')x=x*10+c-'0',c=getchar();return x;}

const int maxn = 1e6+10;
const int mod = 1e9+7;
const double eps = 1e-7;

int n;
ll fast(ll a, int k) {
    if(k==0) return 1;
    ll now=fast(a,k/2);
    now=now*now%mod;
    if(k&1) now=now*a%mod;
    return now;
}

ll inv(ll a) { return fast(a,mod-2); }

int head[maxn], nxt[maxn*2], to[maxn*2], tot;
int dp[maxn], vis[maxn], mx;
ll f[maxn];

inline void add_edge(int u, int v) {
    ++tot, to[tot]=v, nxt[tot]=head[u], head[u]=tot;
    ++tot, to[tot]=u, nxt[tot]=head[v], head[v]=tot; //千万记得双向边呀!!!WA死我了,赛后双向边乱A都行
}

void predfs(int u, int fa, int d) {
    dp[u]=d; mx=max(mx,d);
    for(int i=head[u]; i; i=nxt[i]) {
        if(to[i]==fa) continue;
        predfs(to[i],u,d+1);
    }
}

void dfs(int u, int fa) {
    if(vis[u]) { f[u]=1; return; }
    int s=0, ss=0;
    ll sum=0;
    for(int i=head[u]; i; i=nxt[i]) {
        if(to[i]==fa) continue;
        int v=to[i];
        dfs(v,u);
        s++;
        if(f[v]) ss++, sum=(sum+f[v])%mod;
    }
    ll ans=sum*inv(s)%mod;
    ans=(1-ans+mod)%mod;
    ans=fast(ans,s);
    ans=(1-ans+mod)%mod;
    f[u]=ans;
}

int main() {
    //ios::sync_with_stdio(false); cin.tie(0);
    n=read();
    for(int i=1; i<n; ++i) {
        int u=read(), v=read();
        add_edge(u,v);
    }
    predfs(1,0,1);
    for(int i=1; i<=n; ++i) if(dp[i]==mx) vis[i]=1;
    dfs(1,0);
    printf("%lld\n", f[1]);
}