原来上一题是多校第二道期望题啊,我怎么没印象

题目链接:https://ac.nowcoder.com/acm/contest/11166/I

这题转移就比上题容易了些,但这题的状态比较难找到。
提取题目中的关键点:p是一个排列,每次选择的数需要大于所有已选的数。因为是一个排列,所以我们可以把当前选的数的大小作为状态
因此,我们可以尝试状态dp[i][j]:上一轮选择了i,前一轮选择了j时的期望轮数。
之后,据次得出状态转移式:


ps.这里式根据 (可能变成的状态的dp值和+单步贡献)*变成该状态的概率 值得到的。
具体看代码吧()

#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define IOS ios_base::sync_with_stdio(0); cin.tie(0);
#define rep(i, a, n) for(int i = a; i <= n; ++ i)
#define per(i, a, n) for(int i = n; i >= a; -- i)
//#define ONLINE_JUDGE
using namespace std;
typedef long long ll;
const int mod=998244353;
template<typename T>void write(T x)
{
    if(x<0)
    {
        putchar('-');
        x=-x;
    }
    if(x>9)
    {
        write(x/10);
    }
    putchar(x%10+'0');
}

template<typename T> void read(T &x)
{
    x = 0;char ch = getchar();ll f = 1;
    while(!isdigit(ch)){if(ch == '-')f*=-1;ch=getchar();}
    while(isdigit(ch)){x = x*10+ch-48;ch=getchar();}x*=f;
}

int gcd(int a,int b){return b==0?a:gcd(b,a%b);}
int lcm(int a,int b){return a/gcd(a,b)*b;};
ll ksm(ll a,ll n){
    ll ans=1;
    while(n){
        if(n&1) ans=(ans*a)%mod;
        a=a*a%mod;
        n>>=1;
    }
    return ans%mod;
}
//==============================================================
const int maxn=5e3+10;
#define int ll
int n,p[maxn],q[maxn];
int cnt[maxn],sum[maxn];
int Inv[maxn];
int dp[maxn][maxn];
int inv(int x){
    return Inv[x];
}

signed main()
{
    #ifndef ONLINE_JUDGE
    freopen("in.txt","r",stdin);
    freopen("out.txt","w",stdout);
    #endif
    //clock_t c1 = clock();
    //===========================================================
    read(n);
    for(int i=0;i<maxn;++i)Inv[i]=ksm(i,mod-2);
    for(int i=1;i<=n;++i)read(p[i]),q[p[i]]=i;
    for(int i=n;i>=1;--i){
        memset(cnt,0,sizeof(cnt));
        memset(sum,0,sizeof(sum));
        for(int j=i+1;j<=n;++j){
            sum[q[j]]+=dp[i][j];
            sum[q[j]]%=mod;
            cnt[q[j]]+=1;
        }
        for(int j=n-1;j>=0;--j){
            sum[j]+=sum[j+1];
            sum[j]%=mod;
            cnt[j]+=cnt[j+1];
        }
        /* for(int j=0;j<=n;++j){
            cerr<<sum[j]<<" "<<cnt[j]<<endl;
        }
        cerr<<"================================"<<endl; */
        for(int j=0;j<i;++j){
            int num=cnt[q[j]];
            int s=sum[q[j]];
            if(num){
                dp[j][i]=(dp[j][i]+s*inv(num)%mod+1)%mod;
            }
        }
    }
    int res=0;
    for(int i=1;i<=n;++i){
        res+=dp[0][i];
        res%=mod;
    }
    res=(res*inv(n))%mod;
    res=(res+1)%mod;
    cout<<res<<endl;
    //===========================================================
    //std::cerr << "Time:" << clock() - c1 << "ms" << std::endl;
    return 0;
}