题号 NC19997
名称 [HAOI2016]字符合并
来源 [HAOI2016]

题目描述

有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。

样例

输入
3 2
101
1 10
1 10
0 20
1 30
输出
40
开始字符串为101
将前两个字符"10"合并成1,获得的10分,当前字符串为11
接着将”11“合成1,获得30分,当前字符串为1已经不能再合并了
总得分为40

算法

(区间dp + 状压dp)

​首先答案要求的是最大的收益,且分数w[i]都是正值,而如果一个字符串可以合并那么合并到不能再合并一定是最优的
看数据范围,将字符一部分合并我们很容易想到用区间dp,但是单纯的区间dp无法维护合并后的字符串的信息
我们观察到k的范围最多只有8,即是说最后得到的字符串的长度一定是小于等于7的,
所以我们多维护一维f[i][j][s]表示将区间[l,r]合并最后得到长度小于k的字符串s的最大收益,s可以用的二进制数表示。
我们将状态空间划分成两大种情况(以下的len为区间[l,r]的长度):

  • (len - 1) % (k - 1) + 1 = 1 (最后得到的字符串长度为1)的时候:(合成1个字符的情况有种,所以s的范围为),只考虑s的最后一位是由哪一个右区间合并而成的

    • 当区间len为1时:f[l][r][s] = 0
    • 当区间长度len>1时:我们从后往前枚举断点mid(取二进制的低位比取高位容易):f[l][r][c[s]] = max(f[l][r][c[s]],f[l][mid][s >> 1] + f[mid + 1][r][s & 1] + w[s])
  • (len - 1) % (k - 1) + 1 != 1(最后得到的字符串长度不为1)的时候:设m=(len - 1) % (k - 1) + 1(m!=1),m为s的长度,最后合成的字符串的情况只有种,只考虑s的最后一位是由哪一个右区间合并而成的

    • 同理我们从后往前枚举每一个断点:f[l][r][s] = max(f[l][r][s],f[l][mid][s >> 1] + f[mid + 1][r][s & 1])
  细节:

我们从后往前枚举断点的时候不需要一个一个枚举,因为我们只考虑s的最后一位字符是通过哪一个右区间合成的,
所以右区间的长度一定是满足( - 1) % (k - 1) + 1 = 1,
只需要枚举即可

(len - 1) % (k - 1) + 1 表示长度为len的字符串最后能得到的不可合并的字符串的长度

时间复杂度 :

​ 不太会计算,假设最差的情况n = 300,k = 8,那么计算次数n * n * (n / (k - 1)) * 2^k 约等于 987,428,571 ,超时了但实际不会达到这么多

C++ 代码

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>

#define P 131

#define lc u << 1
#define rc u << 1 | 1

using namespace std;
typedef long long LL;
const int N = 1010;
const LL INF = 0x3f3f3f3f3f3f3f3fll;
LL f[310][310][(1 << 8) + 10];
int w[N],c[N];
char s[N];
int n,k,m;

void solve()
{
    scanf("%d%d",&n,&k);
    scanf("%s",s + 1);
    for(int i = 0;i < 1 << k;i ++) scanf("%d%d",&c[i],&w[i]);
    memset(f,-0x3f,sizeof f);
    for(int len = 1;len <= n;len ++)
    {
        int m = len;
        while(m >= k) m -= k - 1;
        for(int l = 1;l + len - 1 <= n;l ++)
        {
            int r = l + len - 1;
            if(len == 1) f[l][l][s[l] - '0'] = 0;
            else
            {
                if(m == 1)
                {
                    for(int x = 0;x < (1 << k);x ++)
                        for(int mid = r - 1;mid >= l;mid -= (k - 1))
                            f[l][r][c[x]] = max(f[l][r][c[x]],f[l][mid][x >> 1] + f[mid + 1][r][x & 1] + w[x]);
                }else
                {
                    for(int x = 0;x < (1 << m);x ++)
                        for(int mid = r - 1;mid >= l;mid -= (k - 1))
                            f[l][r][x] = max(f[l][r][x],f[l][mid][x >> 1] + f[mid + 1][r][x & 1]);
                }
            }
        }
    }
    LL res = -INF;
    for(int i = 0;i < 1 << k;i ++)
        res = max(res,f[1][n][i]);
    printf("%lld\n",res);
}

int main()
{
    #ifdef LOCAL
    freopen("in.txt", "r", stdin);
    freopen("out.txt", "w", stdout);
    #else
    #endif // LOCAL
    int T = 1;
    // init();
    // scanf("%d",&T);
    while(T --)
    {
        solve();
    }
    return 0;
}