区间dp

区间dp有点道家的思想:一生二,二生三,三生万物。

区间dp的思想就是化繁为简,将大的区间化成两个小的区间,然后递推求解。

石子合并

石子归并

N堆石子摆成一条线。现要将石子有次序地合并成一堆。规定每次只能选相邻的2堆石子合并成新的一堆,并将新的一堆石子数记为该次合并的代价。计算将N堆石子合并成一堆的最小代价。

例如: 1 2 3 4,有不少合并方法

1 2 3 4 => 3 3 4(3) => 6 4(9) => 10(19)

1 2 3 4 => 1 5 4(5) => 1 9(14) => 10(24)

1 2 3 4 => 1 2 7(7) => 3 7(10) => 10(20)

括号里面为总代价可以看出,第一种方法的代价最低,现在给出n堆石子的数量,计算最小合并代价。

输入

第1行:N(2 <= N <= 100)
第2 - N + 1:N堆石子的数量(1 <= A[i] <= 10000)

输出

输出最小合并代价

输入样例

4

1 2 3 4

输出样例

19

这个是最经典的区间dp问题。

对于区间 [ i , j ] ,我们可以把它分成两个区间 [ i , k ] 和 [ k+1 , j ],那么我们就可以通过这两个小区间的值来计算出大区间的值。

这是归并的过程(1,2,3,4为编号):

可以发现的是,每次合并都是两个小区间合并成一个大区间,因此,我们只要计算出每个小区间所需要的最小花费,就能计算出合并后的大区间的最小花费。

代码如下:


#include<bits/stdc++.h>
using namespace std;
const int maxn=100+10;
const int INF=0x3f3f3f3f;
int n;
int dp[maxn][maxn];//表示合并区间[i,j]的最小花费
int sum[maxn];//从 1 ~ i 的和
int main()
{
    int x;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&x);
        sum[i]=sum[i-1]+x;
    }
    memset(dp,INF,sizeof(dp));
    for(int i=1;i<=n;i++){
        dp[i][i]=0;
    }
    
    for(int len=2;len<=n;len++){//计算长为len的区间的值
        for(int i=1;i<=n;i++){
            //区间长为len,起点为i,终点为i+len-1
            int j=i+len-1;
            if(j>n) break;
            
            for(int k=i;k<j;k++){//枚举所有的小区间
                dp[i][j]=min(dp[i][j],dp[i][k]+dp[k+1][j]+sum[j]-sum[i-1]);//sum[j]-sum[i-1]表示合并大区间[i,j]所需要添加的花费
            }
        }
    }
    printf("%d\n",dp[1][n]);
    return 0;
}
这个代码的时间复杂度为O(n3),当n比较大时就没办法在规定时间内求解。

这是我们就可以用四边形不等式(别问,问就是百度)来优化。当我们求解dp[ i , j - 1 ]时,我们可以记录最小花费的 k 的位置,即s[ i ][ j-1 ]=k1。同样的,求解dp[ i + 1 ][ j ]时,我们也可以求出s[ i + 1 s][ j ] =k2

四边形不等式

如果对于任意的a1≤a2<b1≤b2,有m[ a1 , b1 ]+m[ a2 , b2 ]≤m[ a1 , b2 ]+m[ a2 , b1 ],那么m[ i , j ]满足四边形不等式。

写成符合这篇博客的形式为:

如果对于任意的 i ≤ i + 1 < j - 1 ≤ j ,有dp[ i , j-1 ]+dp[ i+1 , j ]≤dp[ i , j ]+dp[ i + 1 , j - 1 ],那么m[ i , j ]满足四边形不等式。

某聚的四边形不等式详细讲解

从这个定理,我们可以推出区间 [ i ][ j ]所对应的 k 值在s[ i ][ j - 1 ] 和s[ i + 1 ][ j ] 之间,然后,我们又可以计算出s[ i ][ j ],依次递推。

代码如下:

#include<bits/stdc++.h>
using namespace std;
const int maxn=100+10;
const int INF=0x3f3f3f3f;
int n;
int dp[maxn][maxn];
int s[maxn][maxn];
int sum[maxn];
int main()
{
    int x;
    scanf("%d",&n);
    for(int i=1;i<=n;i++){
        scanf("%d",&x);
        sum[i]=sum[i-1]+x;
    }
    memset(dp,INF,sizeof(dp));
    for(int i=1;i<=n;i++){
        dp[i][i]=0;
        s[i][i]=i;//初始化
    }
    for(int len=2;len<=n;len++){
        for(int i=1;i<=n;i++){
            int j=i+len-1;
            if(j>n) break;
            
            for(int k=s[i][j-1];k<=s[i+1][j];k++){
                if( dp[i][j] > dp[i][k] + dp[k+1][j] + sum[j] - sum[i-1]){
                    dp[i][j] = dp[i][k] + dp[k+1][j] + sum[j] - sum[i-1];
                    s[i][j]=k;//记录最小花费对应的k值
                }
            }
        }
    }
    printf("%d\n",dp[1][n]);
    return 0;
}
这样我们就能把复杂度从O(n3)减少到O(n2)。