牛客331194I多校 - The Great Wall II

题意

  • 给出一个长度为 n(n8000)n(n\leq 8000) 的序列,你需要将序列切割成 kk 段,每段对答案的贡献为这一段的最大值
  • 需要最小化答案
  • 你需要对于所有的 k[1,n]k \in [1,n],回答每个询问。

思路

初步思路

  • 考虑 DP。
  • dp[i][j]dp[i][j] 为前 ii 个物品,分成 jj 段,的最小值。
  • 转移:dp[i][j]min(dp[x][j1])+max(axai)dp[i][j] \leftarrow \min (dp[x][j-1])+\max(a_x \dots a_i)
  • 复杂度:O(n3)O(n^3)

优化

  • 观察转移方程:dp[i][j]min(dp[x][j1])+max(axai)dp[i][j] \leftarrow \min (dp[x][j-1])+\max(a_x \dots a_i)
    • 我们是否可以对于每个 dp[i][j]dp[i][j],使用单调栈快速地求出 min(dp[x][j1])\min (dp[x][j-1]) 实现 O(n2)O(n^2) 的转移?
    • 不可以。dp[x][j1]dp[x][j-1] 的确是最小的,但是加上后面的那个最大值,就不一定是最优的了。

    证明:这里我们分析区间 [x,i][x,i]。随着 xx 指针的向左移动,min(dp[x][j1])\min (dp[x][j-1]) 单调递减,但是 max(axai)\max(a_x \dots a_i) 单调递增。

  • 我们能观察到,在上面的转移方程中,随着 xx 指针的向左移动,min(dp[x][j1])\min (dp[x][j-1]) 单调递减,但是 max(axai)\max(a_x \dots a_i) 单调递增。
  • 所以对于一个 aia_i而言,在序列中存在一段区间 apaia_p \dots a_i,满足 max(apai)=ai\max (a_p \dots a_i)=a_i
  • 这段区间之前的部分,也就是 a1ap1a_1 \dots a_{p-1},满足 max(apai)>ai\max (a_p \dots a_i)> a_i
  • 转移:
    • 对于 j[p,i1]j\in [p,i-1]tmp1=min(dp[pi1])+aitmp1=\min(dp[p\dots i-1])+a_i
    • 对于 j[1,p1]j\in [1,p-1]tmp2=min(dp[1p1])tmp2=\min(dp[1\dots -p-1])
    • dp[i]=min(tmp1,tmp2)dp[i]=\min(tmp1,tmp2)
  • 以上转移可以用单调栈维护。
  • 既然要找分界点 pp,那么单调栈一定要维护当前 aia_i 的值,而且这个单调栈是由 aia_i 的单调递减的单调栈。
  • 显然,还要维护当前的 dpdpcur_dpcur\_dp,和之前的最小的 dpdppre_dppre\_dp
  • 转移:
    • 弹栈条件:stk.top.a<ai\text{stk.top.}a < a_i,弹栈的同时 min_dp=min(stk.top.cur_dp+ai)min\_dp=\min(\text{stk.top.}cur\_dp+a_i)
    • 弹栈停止意味着找到了分界点。此时在这之前意味着满足 max(apai)>ai\max (a_p \dots a_i)> a_i 的条件,那么 min_dp=min(min_dp,stk.top.pre_dp)min\_dp=\min(min\_dp,\text{stk.top.}pre\_dp)
    • 当前的 dpdp 值就是 min_dpmin\_dp
    • 入栈:cur_dp=min_dpcur\_dp=min\_dppredp=min(stk.top.pre_dp,cur_dp)pre_dp=\min(\text{stk.top.}pre\_dp,cur\_dp)

代码

#include <cstdio>
#include <iostream>
#include <stack>
#define int long long
const int N	= 8010;
const int INF	= 1e9;
using namespace std;

struct STK{int ai,dp,val;};
stack<STK> stk;
int dp[N][N], ai[N];
int n;


void Sol()
{
	for (int i=0; i<=n; i++)
	{
		for (int j=0; j<=n; j++)
		{
			dp[i][j] = INF;
		}
	}
	dp[0][0] = 0;
	
	for (int j=1; j<=n; j++)
	{
		while (!stk.empty())
			stk.pop();
		
		
		for (int i=j; i<=n; i++)
		{
			int min_dp = dp[i-1][j-1];
			
			while (!stk.empty() && stk.top().ai <= ai[i])
			{
				min_dp = min(min_dp, stk.top().dp);
				stk.pop();
			}
			
			int min_val = INF;
			if(!stk.empty())
				min_val = stk.top().val;
			
			dp[i][j] = min(min_val, min_dp+ai[i]);
			
			stk.push({ai[i], min_dp, dp[i][j]});
			
		}
	}
	
	for (int i=1; i<=n; i++)
	{
		printf("%lld\n",dp[n][i]);
	}
	
}

signed main()
{
	scanf("%lld",&n);
	for (int i=1; i<=n; i++)
	{
		scanf("%lld",&ai[i]);
	}
	Sol();
	
	return 0;
}