题号 NC19997
名称 [HAOI2016]字符合并
来源 [HAOI2016]
题目描述
有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。
样例
输入 3 2 101 1 10 1 10 0 20 1 30 输出 40
开始字符串为101 将前两个字符"10"合并成1,获得的10分,当前字符串为11 接着将”11“合成1,获得30分,当前字符串为1已经不能再合并了 总得分为40
算法
(区间dp + 状压dp)
首先答案要求的是最大的收益,且分数w[i]
都是正值,而如果一个字符串可以合并那么合并到不能再合并一定是最优的
看数据范围,将字符一部分合并我们很容易想到用区间dp,但是单纯的区间dp无法维护合并后的字符串的信息
我们观察到k的范围最多只有8,即是说最后得到的字符串的长度一定是小于等于7的,
所以我们多维护一维f[i][j][s]
表示将区间[l,r]
合并最后得到长度小于k的字符串s的最大收益,s可以用的二进制数表示。
我们将状态空间划分成两大种情况(以下的len为区间[l,r]的长度):
(len - 1) % (k - 1) + 1 = 1 (最后得到的字符串长度为1)的时候:(合成1个字符的情况有种,所以s的范围为),只考虑s的最后一位是由哪一个右区间合并而成的
- 当区间len为1时:
f[l][r][s] = 0
- 当区间长度len>1时:我们从后往前枚举断点mid(取二进制的低位比取高位容易):
f[l][r][c[s]] = max(f[l][r][c[s]],f[l][mid][s >> 1] + f[mid + 1][r][s & 1] + w[s])
- 当区间len为1时:
(len - 1) % (k - 1) + 1 != 1(最后得到的字符串长度不为1)的时候:设m=(len - 1) % (k - 1) + 1(m!=1),m为s的长度,最后合成的字符串的情况只有种,只考虑s的最后一位是由哪一个右区间合并而成的
- 同理我们从后往前枚举每一个断点:
f[l][r][s] = max(f[l][r][s],f[l][mid][s >> 1] + f[mid + 1][r][s & 1])
- 同理我们从后往前枚举每一个断点:
细节:
我们从后往前枚举断点的时候不需要一个一个枚举,因为我们只考虑s的最后一位字符是通过哪一个右区间合成的,
所以右区间的长度一定是满足( - 1) % (k - 1) + 1 = 1,
只需要枚举即可
(len - 1) % (k - 1) + 1 表示长度为len的字符串最后能得到的不可合并的字符串的长度
时间复杂度 :
不太会计算,假设最差的情况n = 300,k = 8,那么计算次数n * n * (n / (k - 1)) * 2^k 约等于 987,428,571 ,超时了但实际不会达到这么多
C++ 代码
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> #include <unordered_map> #include <map> #include <vector> #include <queue> #include <set> #include <bitset> #include <cmath> #define P 131 #define lc u << 1 #define rc u << 1 | 1 using namespace std; typedef long long LL; const int N = 1010; const LL INF = 0x3f3f3f3f3f3f3f3fll; LL f[310][310][(1 << 8) + 10]; int w[N],c[N]; char s[N]; int n,k,m; void solve() { scanf("%d%d",&n,&k); scanf("%s",s + 1); for(int i = 0;i < 1 << k;i ++) scanf("%d%d",&c[i],&w[i]); memset(f,-0x3f,sizeof f); for(int len = 1;len <= n;len ++) { int m = len; while(m >= k) m -= k - 1; for(int l = 1;l + len - 1 <= n;l ++) { int r = l + len - 1; if(len == 1) f[l][l][s[l] - '0'] = 0; else { if(m == 1) { for(int x = 0;x < (1 << k);x ++) for(int mid = r - 1;mid >= l;mid -= (k - 1)) f[l][r][c[x]] = max(f[l][r][c[x]],f[l][mid][x >> 1] + f[mid + 1][r][x & 1] + w[x]); }else { for(int x = 0;x < (1 << m);x ++) for(int mid = r - 1;mid >= l;mid -= (k - 1)) f[l][r][x] = max(f[l][r][x],f[l][mid][x >> 1] + f[mid + 1][r][x & 1]); } } } } LL res = -INF; for(int i = 0;i < 1 << k;i ++) res = max(res,f[1][n][i]); printf("%lld\n",res); } int main() { #ifdef LOCAL freopen("in.txt", "r", stdin); freopen("out.txt", "w", stdout); #else #endif // LOCAL int T = 1; // init(); // scanf("%d",&T); while(T --) { solve(); } return 0; }