牛客9510C - 排列

题意

定义超级逆序对为:满足 i<ji<jai>aj+1a_i>a_j+1 的二元组 (i,j)(i,j) 求长度为 nn 的,超级逆序对数量为 KK排列的数量 的倒数(分之一)。 1n5001\leq n\leq 5000K5000\leq K\leq 500

思路

先考虑普通逆序对

求长度为 nn 的,普通逆序对数量为 KK排列的数量

我们将 1,2,...,n1,2,...,n 从小到大依次放入序列。 假设对于当前数 ii, 如果我们将它放到序列的最后,那么对逆序对数量的贡献为 00, 如果我们将它放到序列的最前,那么对逆序对数量的贡献为 i1i-1, 也就是说,对于当前数 ii,放入当前序列,对逆序对数量的贡献的范围一定是 [0,i1][0,i-1]

我们同样也发现,对于当前数 ii,放入当前序列,有正好 ii 个空缺的部分可以插入,插到最后的贡献是 00,插到倒数第二的空处为 11,以此类推,插到最前为 i1i-1

  • 那我们是不是要把状态设计成 dp[i][j]dp[i][j] 代表,对于长度为 ii 的序列,第 ii 个数字插到位置 jj 的排列方案数?
  • 不能,注意看题,要求普通逆序对数量上限为 KK,因此我们的状态一定要表示出逆序对数量。

我们用 dp[i][j]dp[i][j] 代表, 对于长度为 ii 的序列,逆序对数量为 jj 的排列方案数。 dp[i][j]=dp[i1][jK...j0]dp[i][j]=\sum dp[i-1][j-K...j-0] 复杂度:O(n2)O(n^2),可以前缀和优化至 O(n2)O(n^2)

优化前
for (int i=1; i<=n; i++)
{
    for (int j=1; j<=K; j++)
    {
        for (int k=0; j-k>=0 && k<=i; k++)
        {
            dp[i][j] += dp[i-1][k];
        }
    }
}

优化后
for (int i=1; i<=n; i++)
{
    for (int j=1; j<=K; j++)
    {
        dp[i][j] = sum[i-1][j] - sum[i-1][ max(j-K-1, 0) ];
        sum[i][j] = sum[i-1][j] + dp[i][j];
    }
}

再考虑题中给出的超级逆序对。 既然要满足满足 i<ji<jai>aj+1a_i>a_j+1,也就是对于第 ii 个数,如果 ii 放在 i1i-1 前边,那么唯独这一对是不算的。也就是,在这里,我们要关心第 ii 个数字具体是放在哪里了。 我们用 dp[i][j][k]dp[i][j][k] 代表, 对于长度为 ii 的序列,超级逆序对数量为 jj,并且第 ii 个数放在 kk 的 排列方案数。

dp[i][j][k]=dp[i1][j(ik)][1...k1]dp[i][j][k]=\sum dp[i-1][j-(i-k)][1...k-1] dp[i][j][k]=dp[i1][j(ik)+1][k...i]dp[i][j][k]=\sum dp[i-1][j-(i-k)+1][k...i]

转移复杂度 O(n4)O(n^4) 前缀和优化到 O(n3)O(n^3)

代码

#include <cstdio>
#include <iostream>
#include <cstring>
#define int long long
const int N		= 501;
const int MOD	= 998244353;
using namespace std;

int sum[2][N][N];
int dp[2][N][N];
int n, K;

long long POW(long long a,long long b)
{
    long long ans=1;
    while(b>0)
    {
        if(b&1)
        {
            ans*=a;
            ans%=MOD;
        }
        a*=a;
        a%=MOD;
        b>>=1;
    }
	
	return ans;
}


int Cal(int x)
{
	return x*(x-1)/2;
}

void Solve()
{
	int cur = 1;
	dp[cur][0][0] = 1;
	sum[cur][0][0] = 1;
	for (int i=1; i<=n; i++)
	{
		cur^=1;
		memset(dp[cur], 0, sizeof(dp[cur]));
		memset(sum[cur], 0, sizeof(sum[cur]));
		for (int j=0; j<=K; j++)
		{
			for (int k=1; k<=i; k++)
			{
				if(j-(i-k) >= 0)
					dp[cur][j][k] += sum[cur^1][ j-(i-k) ][ k-1 ], dp[cur][j][k]%=MOD;
			
				if(j-(i-k)+1 >= 0)
				{
					dp[cur][j][k] += (sum[cur^1][ j-(i-k)+1 ][ i-1 ] - sum[cur^1][ j-(i-k)+1 ][ k-1 ] + MOD)%MOD;
					dp[cur][j][k]%=MOD;
				}
				sum[cur][j][k] += (sum[cur][j][k-1] + dp[cur][j][k])%MOD, sum[cur][j][k]%=MOD;
			}
		}
	}
	int ans=0;
	for (int i=0; i<=n; i++)
	{
		ans += dp[cur][K][i], ans%=MOD;
	}
	
	ans = 1 * POW(ans, MOD-2) % MOD;
	
	printf("%lld\n",ans);
}

signed main()
{
	scanf("%lld %lld",&n,&K);
	Solve();
	return 0;
}