G - Glass Balls

题意:

  • 给出一颗根为的树,每个节点上刚开始都有一个球向它的父亲滚动,一个球在一个单位的时间内只能滚一条边。
  • 存在一种特殊的节点,这种节点出现的概率为,一旦有球滚到这个节点,这个球就被拿出这棵树。
  • 存在一种特殊的情况,当两个球滚到同一个节点,不论这个节点是否特殊,整个系统就会崩溃掉,并且你得到的分数为
  • 定义为刚开始在第个节点上的球滚过的边数,定义分数为,求期望的分数

题解: 首先需要找到这题的突破口,那就是要先找到使系统不崩溃的概率。显然的是,要想系统不崩溃,对于一个父节点来说,若它有个子节点,那这个子节点中至少要有个特殊节点,所以系统不崩溃的概率为
理解为有个特殊节点和个特殊节点的和
接下来我们考虑如何计算期望。
首先作为孩子节点,它的期望可以从父亲节点那里转移过来,得到初步的转移式,但还需要考虑节点的合法性,即若想从儿子节点出去,那么它必须是普通节点,那么对于的所有儿子节点中,只有这个节点为普通节点的概率为
所以完整的转移方程为
所以最后的答案别忘了乘上系统不崩溃的概率$$

#include<bits/stdc++.h>
using namespace std;

#define dbg(x...) do { cout << #x << " -> "; err(x); } while (0)
void err () {   cout << endl;}
template <class T, class... Ts>
void err(const T& arg, const Ts&... args) { cout << arg << ' '; err(args...);}

#define ll long long
#define ull unsigned long long
#define LL __int128
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#define pii pair<int, int>
#define PII pair<ll, ll>
#define tint tuple<int, int, int> 
#define fi first
#define se second
#define pb push_back
#define eb emplace_back
#define em emplace
#define mp(a,b) make_pair(a,b)
#define all(x) (x).begin(), (x).end()
#define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0);
#define PAUSE system("pause");
const double Pi = acos(-1.0);
const double eps = 1e-8;
const int maxn = 5e5 + 10;
const int maxm = 1e5 + 10;
const int mod = 998244353;

inline ll rd() {
    ll f = 0; ll x = 0; char ch = getchar();
    for (; !isdigit(ch); ch = getchar()) f |= (ch == '-');
    for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0';
    if (f) x = -x;
    return x;
}
void out(ll a) {if(a<0)putchar('-'),a=-a;if(a>=10)out(a/10);putchar(a%10+'0');}
#define pt(x) out(x),puts("")
inline void swap(ll &a, ll &b){ll t = a; a = b; b = t;}
inline void swap(int &a, int &b){int t = a; a = b; b = t;}
inline ll min(ll a, ll b){return a < b ? a : b;}
inline ll max(ll a, ll b){return a > b ? a : b;}
ll qpow(ll n,ll k,ll mod) {ll ans=1;while(k){if(k&1)ans=ans*n%mod;n=n*n%mod;k>>=1;}return ans%mod;}
mt19937 rng(chrono::steady_clock::now().time_since_epoch().count());

int n, p;
vector<int> G[maxn];
ll dp[maxn];

void dfs(int now, int fa) {
    for(auto g : G[now]) {
        ll tmp = (1 - p + mod) * qpow(p, G[now].size() - 1, mod) % mod;
        tmp = tmp * qpow((G[now].size() * (1 - p + mod) % mod * qpow(p, G[now].size() - 1, mod) % mod + qpow(p, G[now].size(), mod)) % mod, mod - 2, mod) % mod;
        dp[g] = (1 + dp[now]) * tmp % mod;
        dfs(g, now);
    }
}

void solve() {
    scanf("%d %d", &n, &p);
    for(int i = 2, x; i <= n; i++) {
        scanf("%d", &x);
        G[x].eb(i);
    }

    ll P = 1;
    for(int i = 1; i <= n; i++) {
        ll sz = G[i].size();
        if(sz == 0) continue;
        P = P * (sz % mod * (1 - p + mod) % mod * qpow(p, sz - 1, mod) % mod + qpow(p, sz, mod)) % mod;
    }
    dfs(1, 0);
    ll ans = 0;
    for(int i = 1; i <= n; i++) {
        ans = (ans + dp[i]) % mod;
    }
    ans = ans * P % mod;
    pt(ans);
}

int main() {
//    freopen("in.txt","r",stdin);
//    freopen("out.txt", "w", stdout);
    int t = 1;
    // t = rd();
    // scanf("%d", &t);
    // cin >> t;
    while(t--)
        solve();
    return 0;
}