思路
方法一
状态方程:
状态方程中的表示的是数组的下标,即位置。
,表示
上一步选了
,
上一步选了
,这一轮是
选择,从上一轮开始游戏进行的回合数的期望(即从
选
开始计数);
,表示
上一步选了
,
上一步选了
,这一轮是
选择,从上一轮开始游戏进行的回合数的期望(即从
选
开始计数)
状态转移方程:
表示事件发生的概率
运行超时的代码:code
考虑一下优化:
首先记忆化搜索保证只会对种情况进行搜索,可以考虑把每种情况里面的一层
循环给消掉。
1:
if(s) {
int cnt=c[p2+1][a[p1]];
for(int i=p2+1;i<=n;++i) if(a[i]>a[p1])
ans=(ans+(dfs(p1,i,0)+1)*inv[cnt]%mod)%mod;
}
2:
if(s){
int cnt=c[p2+1][a[p1]];
if(a[p1]>a[p2]){
ans=qu(ans+qu(qu(dfs(p1,p2+1,0)*inv[cnt])+1ll));
}
ans=qu(ans+dfs(p1+1,p2,1));
} 代码中当
时才能产生新的情况。如果记
,调用
求
时,当
时,表明这个状态合法,这时才会进行 新的状态调用
,接着会套娃调用
。就可以成功消掉一层循环了。
列如,下面是调用后的部分过程
dfs(p1+1,p2,1);//a[p2]>a[p1] dfs(p1,p2+1,0); dfs(p1+1+1,p2,1)//a[p2]>a[p1+1+1] dfs(p1+1,p2+1,0); ...
假如,那么答案其实就是
,
表示
选
选
,选择轮到
来选,进行回合数的期望(包括
选
的这一回合),直接写成
就行了。
code:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int maxn=5002,mod=998244353;
int n,a[maxn],c[maxn][maxn];
int inv[maxn];
int dp[maxn][maxn][2];
ll dfs(int p1,int p2,int s){
if(p1>n||p2>n) return 0;
if(~dp[p1][p2][s]) return dp[p1][p2][s];
ll ans=0;
if(s==0){
int cnt=c[p1+1][a[p2]];
if(a[p2]>a[p1])
ans=(ans+((dfs(p1+1,p2,1)*inv[cnt])%mod+1ll)%mod)%mod;
ans=(ans+dfs(p1,p2+1,0))%mod;
}else{
int cnt=c[p2+1][a[p1]];
if(a[p1]>a[p2])
ans=(ans+((dfs(p1,p2+1,0)*inv[cnt])%mod+1ll)%mod)%mod;
ans=(ans+dfs(p1+1,p2,1))%mod;
}
return dp[p1][p2][s]=ans;
}
int main(){
scanf("%d",&n);
memset(dp,-1,sizeof dp);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
inv[1] = 1;for(int i = 2;i <= n;i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
for(int i=n;i>=1;--i){
for(int j=1;j<=n;++j){
c[i][j]=c[i+1][j];
if(j<a[i]) c[i][j]++;
}
}
ll res=0;
for(int i=1;i<=n;i++){
int cnt=c[1][a[i]];
res=(res+(((dfs(i,1,0)*inv[cnt])+1ll)%mod*inv[n])%mod)%mod;
//优化前的代码:res=(res+(dfs(i,0,1)+1)*inv[n]%mod)%mod;
}
printf("%lld\n",res);
} 方法二
思路:
设表示一个人上一次选的是
另一个人上一次选的是
,接下来游戏结束的期望值。这里
表示的是值,不是下标。
那么就有:
从大到小枚举值,接着从大到小枚举位置
。如果
,记录
的前缀和以及个数,否则更新
的值。即:
#include <bits/stdc++.h>
using namespace std;
const int mod=998244353;
int n,a[5005],inv[5005],f[5005][5005];
int main() {
scanf("%d",&n);
inv[1] = 1;
for(int i = 2; i <= n; i++) inv[i] = 1ll * (mod - mod / i) * inv[mod % i] % mod;
for(int i=1; i<=n; i++)scanf("%d",&a[i]);
for(int j=n; ~j; j--) {
int sum=0,cnt=0;
for(int i=n; ~i; i--) {
if(a[i]==j) continue;
if(a[i]>j) cnt++,sum=(sum+f[j][a[i]])%mod;
else f[a[i]][j]=(1ll*sum*inv[cnt]+1)%mod;
}
}
int ans = 0;
for(int i = 1; i <= n; i++) ans = (ans + f[0][i]) % mod;
printf("%d", 1ll * ans * inv[n] % mod);
return 0;
}

京公网安备 11010502036488号