思路
方法一
状态方程:
状态方程中的表示的是数组的下标,即位置。
,表示上一步选了,上一步选了,这一轮是选择,从上一轮开始游戏进行的回合数的期望(即从选开始计数);
,表示上一步选了,上一步选了,这一轮是选择,从上一轮开始游戏进行的回合数的期望(即从选开始计数)
状态转移方程:
表示事件发生的概率
运行超时的代码:code
考虑一下优化:
首先记忆化搜索保证只会对种情况进行搜索,可以考虑把每种情况里面的一层循环给消掉。
1: if(s) { int cnt=c[p2+1][a[p1]]; for(int i=p2+1;i<=n;++i) if(a[i]>a[p1]) ans=(ans+(dfs(p1,i,0)+1)*inv[cnt]%mod)%mod; } 2: if(s){ int cnt=c[p2+1][a[p1]]; if(a[p1]>a[p2]){ ans=qu(ans+qu(qu(dfs(p1,p2+1,0)*inv[cnt])+1ll)); } ans=qu(ans+dfs(p1+1,p2,1)); }
代码中当时才能产生新的情况。如果记,调用求时,当时,表明这个状态合法,这时才会进行 新的状态调用,接着会套娃调用。就可以成功消掉一层循环了。
列如,下面是调用后的部分过程
dfs(p1+1,p2,1);//a[p2]>a[p1] dfs(p1,p2+1,0); dfs(p1+1+1,p2,1)//a[p2]>a[p1+1+1] dfs(p1+1,p2+1,0); ...
假如,那么答案其实就是,表示选选,选择轮到来选,进行回合数的期望(包括选的这一回合),直接写成就行了。
code:
#include<bits/stdc++.h> using namespace std; typedef long long ll; const int maxn=5002,mod=998244353; int n,a[maxn],c[maxn][maxn]; int inv[maxn]; int dp[maxn][maxn][2]; ll dfs(int p1,int p2,int s){ if(p1>n||p2>n) return 0; if(~dp[p1][p2][s]) return dp[p1][p2][s]; ll ans=0; if(s==0){ int cnt=c[p1+1][a[p2]]; if(a[p2]>a[p1]) ans=(ans+((dfs(p1+1,p2,1)*inv[cnt])%mod+1ll)%mod)%mod; ans=(ans+dfs(p1,p2+1,0))%mod; }else{ int cnt=c[p2+1][a[p1]]; if(a[p1]>a[p2]) ans=(ans+((dfs(p1,p2+1,0)*inv[cnt])%mod+1ll)%mod)%mod; ans=(ans+dfs(p1+1,p2,1))%mod; } return dp[p1][p2][s]=ans; } int main(){ scanf("%d",&n); memset(dp,-1,sizeof dp); for(int i=1;i<=n;i++) scanf("%d",&a[i]); inv[1] = 1;for(int i = 2;i <= n;i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; for(int i=n;i>=1;--i){ for(int j=1;j<=n;++j){ c[i][j]=c[i+1][j]; if(j<a[i]) c[i][j]++; } } ll res=0; for(int i=1;i<=n;i++){ int cnt=c[1][a[i]]; res=(res+(((dfs(i,1,0)*inv[cnt])+1ll)%mod*inv[n])%mod)%mod; //优化前的代码:res=(res+(dfs(i,0,1)+1)*inv[n]%mod)%mod; } printf("%lld\n",res); }
方法二
思路:
设表示一个人上一次选的是另一个人上一次选的是,接下来游戏结束的期望值。这里表示的是值,不是下标。
那么就有:
从大到小枚举值,接着从大到小枚举位置。如果,记录的前缀和以及个数,否则更新的值。即:
#include <bits/stdc++.h> using namespace std; const int mod=998244353; int n,a[5005],inv[5005],f[5005][5005]; int main() { scanf("%d",&n); inv[1] = 1; for(int i = 2; i <= n; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod; for(int i=1; i<=n; i++)scanf("%d",&a[i]); for(int j=n; ~j; j--) { int sum=0,cnt=0; for(int i=n; ~i; i--) { if(a[i]==j) continue; if(a[i]>j) cnt++,sum=(sum+f[j][a[i]])%mod; else f[a[i]][j]=(1ll*sum*inv[cnt]+1)%mod; } } int ans = 0; for(int i = 1; i <= n; i++) ans = (ans + f[0][i]) % mod; printf("%d", 1ll * ans * inv[n] % mod); return 0; }