题目大意

你有一个随机数生成器,他会随机生成之间某个数,生成的概率为

现在分为大体步操作:

  1. 随机生成一个
  2. 如果这个​是已经生成的数中最大的,返回步骤一继续生成新的数,否则进入步骤三
  3. 游戏结束,本局游戏的得分为生成序列的长度的平方​。

现在要你求出这局游戏的得分期望

Solution

考点:期望dp

这个期望不是很好求,我们优先考虑求出

我们假设​​​是随机到​​​​后还能进行的期望抽取次数,那么我们知道​​,那么我们把之后的抽取分为​​​这三类,可以知道,抽到​​​之后就无法进行了这时候的期望次数就是​​​,如果抽到了相同的数​​​,这时候他带来的期望次数就是​​​,如果抽到了​​​,他们的期望次数是​​​。

我们总结上面三种情况就可以得到下面式子:

进行移相化简,并且知道​之后我们可以得到这样的最终式子。

看出这个是一个倒序的递推式,倒序的处理即可求到的全部值,也就是


接下来题目要求的是的期望,在概率论中,有这样的式子,也就是方差等于平方的期望减掉期望的平方,所以我们不能把划等号。

和求​同理,我们假设​代表随机到​后游戏的期望得分也就是剩余回合数的平方,那么我们知道

​同样的对下次抽取的数字以​分三类情况,可以得到​的递推式。

同样的倒序递推就可以得到全部的值。

下面套用我们求解

套用快速幂和费马小定理求解逆元就出来答案了。

#include <bits/stdc++.h>
using namespace std;
#define int long long
typedef long long ll; typedef unsigned long long ull; typedef long double ld;
inline ll read() { ll s = 0, w = 1; char ch = getchar(); for (; !isdigit(ch); ch = getchar()) if (ch == '-') w = -1; for (; isdigit(ch); ch = getchar())    s = (s << 1) + (s << 3) + (ch ^ 48); return s * w; }
inline void print(ll x, int op = 10) { if (!x) { putchar('0'); if (op)    putchar(op); return; }    char F[40]; ll tmp = x > 0 ? x : -x;    if (x < 0)putchar('-');    int cnt = 0;    while (tmp > 0) { F[cnt++] = tmp % 10 + '0';        tmp /= 10; }    while (cnt > 0)putchar(F[--cnt]);    if (op)    putchar(op); }
ll qpow(ll a, ll b) { ll ans = 1;    while (b) { if (b & 1)    ans *= a;        b >>= 1;        a *= a; }    return ans; }    ll qpow(ll a, ll b, ll mod) { ll ans = 1; while (b) { if (b & 1)(ans *= a) %= mod; b >>= 1; (a *= a) %= mod; }return ans % mod; }
const int MOD = 998244353;

const int N = 100 + 7;
ll n, m;
int w[N], p[N], inv[N];
int f[N], g[N];

inline int add(int a, int b) {
    if (a + b >= MOD)    return a + b - MOD;
    return a + b;
}

int solve() {
    n = read();
    int sum = 0;
    for (int i = 1; i <= n; ++i) {
        w[i] = read();
        sum += w[i];
    }
    int tmp = qpow(sum % MOD, MOD - 2, MOD);
    for (int i = 1; i <= n; ++i) {
        p[i] = w[i] * tmp % MOD;
        inv[i] = qpow((1 - p[i] + MOD) % MOD, MOD - 2, MOD);
    }
    sum = 0;
    for (int i = n; i >= 1; --i) {
        f[i] = add(1, sum) * inv[i] % MOD;
        sum = add(sum, p[i] * f[i] % MOD);
    }
    sum = 0;
    for (int i = n; i >= 1; --i) {
        g[i] = add(add(1, 2 * p[i] * f[i] % MOD), sum) * inv[i] % MOD;
        sum = add(sum, p[i] * add(g[i], 2 * f[i] % MOD) % MOD);
    }
    int res = 0;
    for (int i = 1; i <= n; ++i) {
        res = add(res, add(g[i], (2 * f[i] + 1) % MOD) * p[i] % MOD);
    }
    print(res);

    return 1;
}

signed main() {
    //int T = read(); for (int i = 1; i <= T; ++i)
    {
        solve();
        //cout << (solve() ? "YES" : "NO") << endl;
    }
    return 0;
}