牛客33785F - 到底是多少分啊

链接:https://ac.nowcoder.com/acm/contest/33785/F 知识点:组合数学,古典概型,计数DP,前缀和优化DP 难度:紫

题意

给出一个长度为 nn 的数列,现在需要执行以下的操作 KK 次:

选择数列中的某个数字,给它加上 11

操作完以后会产生很多不同的数列,现在从中随机挑选一个数列。 问该数列所有元素乘积的期望。

思路

错误思路:

考虑每种情况对答案的贡献,将贡献累加就是最终答案。 计算每种情况对答案的贡献太复杂,比如:先将 a1a_1 加上1,再将 a2a_2 加上1,和先将 a2a_2 加上1,再将 a1a_1 加上1。 最后的数列是一样的,但是你计算每种情况对答案的贡献不能去掉这种重复的情况。

正确思路:

只能考虑古典概型。 ans=ans=\frac{所有情况的总和}{情况总数}

情况总数怎么算?

错误思路

Cn+K1KC_{n+K-1}^{K} 这是从 nn 个元素中选出 KK 个,每种元素取的数量任意,取的方法总数。 错误原因:例如,先将 a1a_1 加上1,再将 a2a_2 加上3;和先将 a1a_1 加上3,再将 a2a_2 加上1。 这是两种不同的方法,而以上组合数只会将这样的算作一类。

正确思路

dp[i][j]dp[i][j] 代表对于前 i 个元素,额外加的数字的总和为 j ,的方法总数。 转移:dp[i][j]=dp[i1][0...j]dp[i][j]=dp[i-1][0...j] 转移复杂度:O(n3)O(n^3) 优化:显然可以预处理前缀和优化。 细节代码:

dp2[0][0] = 1;
for (int i=1; i<=n; i++)
{
    int sum=0;//前缀和
    for (int j=0; j<=K; j++)
    {
        sum += dp2[i-1][j], sum%=MOD;
        dp2[i][j] += sum, dp2[i][j]%=MOD;
    }
}

所有情况的总和怎么算?

n 和 K 都不大,考虑背包转移。 仔细看题,转移时,我们的限制是:额外加的数的总和为 K ,我们得出背包大小为 K。 dp[i][j]dp[i][j] 代表 对于前 i 个元素,我们额外加了 j 个数,的方法数目总和。 转移:dp[i][j]=(dp[i1][0...j]×(arr[i]+delta))dp[i][j]=\sum (dp[i-1][0...j]\times (arr[i]+delta)),其中 delta=j(0..j)delta=j-(0..j) 转移复杂度:O(n3)O(n^3) 优化:第一重循环是枚举 i,第二重是枚举 j,第三重是枚举 0...j0...j。显然我们要优化第三重循环。 一看到“++”和“\sum”,我们就考虑将加号拆开。

  1. 对于dp[i][j]=dp[i1][0...j]×arr[i]dp[i][j]=\sum dp[i-1][0...j]\times arr[i] 这一部分 显然可以前缀和优化。
dp[0][0] = 1;
for (int i=1; i<=n; i++)
{
    int sum=0;//前缀和
    for (int j=0; j<=K; j++)
    {
        sum += dp[i-1][j], sum%=MOD;   
        dp[i][j] = (sum*arr[i])%MOD, dp[i][j]%=MOD;
    }
}
  1. 对于 dp[i][j]=dp[i1][0...j]×deltadp[i][j]=\sum dp[i-1][0...j]\times delta 这部分, 我们写出 dp[i][j]dp[i][j]dp[i][j1]dp[i][j-1],再尝试错位相减。同样也可以前缀和优化。

代码

#include <cstdio>
#include <iostream>
#define int long long
const int N	= 3010;
const int MOD	= 998244353;
using namespace std;

long long POW(long long a,long long b)
{
	long long res=1;
	while(b)
	{
		if(b&1)
		{
			res*=a, res%=MOD;
		}
		b>>=1;
		a*=a, a%=MOD;
	}
	return res;
}

int dp[N][N],dp2[N][N];

int arr[N];
int n, K;

void Solve()
{
	dp[0][0] = 1;
	for (int i=1; i<=n; i++)
	{
		int tmp = 0;
		int sum=0;
		for (int j=0; j<=K; j++)
		{
			tmp += sum, tmp%=MOD;
			sum += dp[i-1][j], sum%=MOD;
			
			dp[i][j] = (sum*arr[i])%MOD + tmp, dp[i][j]%=MOD;
		}
	}
	
	dp2[0][0] = 1;
	for (int i=1; i<=n; i++)
	{
		int sum=0;
		for (int j=0; j<=K; j++)
		{
			sum += dp2[i-1][j], sum%=MOD;
			dp2[i][j] += sum, dp2[i][j]%=MOD;
		}
	}
	int ans = dp[n][K] * POW(dp2[n][K], MOD-2) % MOD;
	printf("%lld\n",ans);
}

signed main()
{
	scanf("%lld %lld",&n,&K);
	for (int i=1; i<=n; i++)
	{
		scanf("%lld",&arr[i]);
	}
	Solve();
	
	return 0;
}