G - Glass Balls
题意:
- 给出一颗根为
的树,每个节点上刚开始都有一个球向它的父亲滚动,一个球在一个单位的时间内只能滚一条边。
- 存在一种特殊的节点,这种节点出现的概率为
,一旦有球滚到这个节点,这个球就被拿出这棵树。
- 存在一种特殊的情况,当两个球滚到同一个节点,不论这个节点是否特殊,整个系统就会崩溃掉,并且你得到的分数为
。
- 定义
为刚开始在第
个节点上的球滚过的边数,定义分数为
,求期望的分数
题解: 首先需要找到这题的突破口,那就是要先找到使系统不崩溃的概率。显然的是,要想系统不崩溃,对于一个父节点来说,若它有
个子节点,那这
个子节点中至少要有
个特殊节点,所以系统不崩溃的概率为
理解为有个特殊节点和
个特殊节点的和
接下来我们考虑如何计算期望。
首先作为孩子节点,它的期望可以从父亲节点那里转移过来,得到初步的转移式,但还需要考虑
节点的合法性,即若想从儿子节点出去,那么它必须是普通节点,那么对于
的所有儿子节点中,只有
这个节点为普通节点的概率为
所以完整的转移方程为
所以最后的答案别忘了乘上系统不崩溃的概率$$
#include<bits/stdc++.h> using namespace std; #define dbg(x...) do { cout << #x << " -> "; err(x); } while (0) void err () { cout << endl;} template <class T, class... Ts> void err(const T& arg, const Ts&... args) { cout << arg << ' '; err(args...);} #define ll long long #define ull unsigned long long #define LL __int128 #define inf 0x3f3f3f3f #define INF 0x3f3f3f3f3f3f3f3f #define pii pair<int, int> #define PII pair<ll, ll> #define tint tuple<int, int, int> #define fi first #define se second #define pb push_back #define eb emplace_back #define em emplace #define mp(a,b) make_pair(a,b) #define all(x) (x).begin(), (x).end() #define IOS ios::sync_with_stdio(false);cin.tie(0);cout.tie(0); #define PAUSE system("pause"); const double Pi = acos(-1.0); const double eps = 1e-8; const int maxn = 5e5 + 10; const int maxm = 1e5 + 10; const int mod = 998244353; inline ll rd() { ll f = 0; ll x = 0; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) f |= (ch == '-'); for (; isdigit(ch); ch = getchar()) x = (x << 1) + (x << 3) + ch - '0'; if (f) x = -x; return x; } void out(ll a) {if(a<0)putchar('-'),a=-a;if(a>=10)out(a/10);putchar(a%10+'0');} #define pt(x) out(x),puts("") inline void swap(ll &a, ll &b){ll t = a; a = b; b = t;} inline void swap(int &a, int &b){int t = a; a = b; b = t;} inline ll min(ll a, ll b){return a < b ? a : b;} inline ll max(ll a, ll b){return a > b ? a : b;} ll qpow(ll n,ll k,ll mod) {ll ans=1;while(k){if(k&1)ans=ans*n%mod;n=n*n%mod;k>>=1;}return ans%mod;} mt19937 rng(chrono::steady_clock::now().time_since_epoch().count()); int n, p; vector<int> G[maxn]; ll dp[maxn]; void dfs(int now, int fa) { for(auto g : G[now]) { ll tmp = (1 - p + mod) * qpow(p, G[now].size() - 1, mod) % mod; tmp = tmp * qpow((G[now].size() * (1 - p + mod) % mod * qpow(p, G[now].size() - 1, mod) % mod + qpow(p, G[now].size(), mod)) % mod, mod - 2, mod) % mod; dp[g] = (1 + dp[now]) * tmp % mod; dfs(g, now); } } void solve() { scanf("%d %d", &n, &p); for(int i = 2, x; i <= n; i++) { scanf("%d", &x); G[x].eb(i); } ll P = 1; for(int i = 1; i <= n; i++) { ll sz = G[i].size(); if(sz == 0) continue; P = P * (sz % mod * (1 - p + mod) % mod * qpow(p, sz - 1, mod) % mod + qpow(p, sz, mod)) % mod; } dfs(1, 0); ll ans = 0; for(int i = 1; i <= n; i++) { ans = (ans + dp[i]) % mod; } ans = ans * P % mod; pt(ans); } int main() { // freopen("in.txt","r",stdin); // freopen("out.txt", "w", stdout); int t = 1; // t = rd(); // scanf("%d", &t); // cin >> t; while(t--) solve(); return 0; }