题号 NC19997
名称 [HAOI2016]字符合并
来源 [HAOI2016]
题目描述
有一个长度为 n 的 01 串,你可以每次将相邻的 k 个字符合并,得到一个新的字符并获得一定分数。得到的新字符和分数由这 k 个字符确定。你需要求出你能获得的最大分数。
样例
输入 3 2 101 1 10 1 10 0 20 1 30 输出 40
开始字符串为101 将前两个字符"10"合并成1,获得的10分,当前字符串为11 接着将”11“合成1,获得30分,当前字符串为1已经不能再合并了 总得分为40
算法
(区间dp + 状压dp)
首先答案要求的是最大的收益,且分数w[i]都是正值,而如果一个字符串可以合并那么合并到不能再合并一定是最优的
看数据范围,将字符一部分合并我们很容易想到用区间dp,但是单纯的区间dp无法维护合并后的字符串的信息
我们观察到k的范围最多只有8,即是说最后得到的字符串的长度一定是小于等于7的,
所以我们多维护一维f[i][j][s]表示将区间[l,r]合并最后得到长度小于k的字符串s的最大收益,s可以用的二进制数表示。
我们将状态空间划分成两大种情况(以下的len为区间[l,r]的长度):
(len - 1) % (k - 1) + 1 = 1 (最后得到的字符串长度为1)的时候:(合成1个字符的情况有
种,所以s的范围为
),只考虑s的最后一位是由哪一个右区间合并而成的
- 当区间len为1时:
f[l][r][s] = 0 - 当区间长度len>1时:我们从后往前枚举断点mid(取二进制的低位比取高位容易):
f[l][r][c[s]] = max(f[l][r][c[s]],f[l][mid][s >> 1] + f[mid + 1][r][s & 1] + w[s])
- 当区间len为1时:
(len - 1) % (k - 1) + 1 != 1(最后得到的字符串长度不为1)的时候:设m=(len - 1) % (k - 1) + 1(m!=1),m为s的长度,最后合成的字符串的情况只有
种,只考虑s的最后一位是由哪一个右区间合并而成的
- 同理我们从后往前枚举每一个断点:
f[l][r][s] = max(f[l][r][s],f[l][mid][s >> 1] + f[mid + 1][r][s & 1])
- 同理我们从后往前枚举每一个断点:
细节:
我们从后往前枚举断点的时候不需要一个一个枚举,因为我们只考虑s的最后一位字符是通过哪一个右区间合成的,
所以右区间的长度一定是满足(
- 1) % (k - 1) + 1 = 1,
只需要枚举即可
(len - 1) % (k - 1) + 1 表示长度为len的字符串最后能得到的不可合并的字符串的长度
时间复杂度 :
不太会计算,假设最差的情况n = 300,k = 8,那么计算次数n * n * (n / (k - 1)) * 2^k 约等于 987,428,571 ,超时了但实际不会达到这么多
C++ 代码
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <unordered_map>
#include <map>
#include <vector>
#include <queue>
#include <set>
#include <bitset>
#include <cmath>
#define P 131
#define lc u << 1
#define rc u << 1 | 1
using namespace std;
typedef long long LL;
const int N = 1010;
const LL INF = 0x3f3f3f3f3f3f3f3fll;
LL f[310][310][(1 << 8) + 10];
int w[N],c[N];
char s[N];
int n,k,m;
void solve()
{
scanf("%d%d",&n,&k);
scanf("%s",s + 1);
for(int i = 0;i < 1 << k;i ++) scanf("%d%d",&c[i],&w[i]);
memset(f,-0x3f,sizeof f);
for(int len = 1;len <= n;len ++)
{
int m = len;
while(m >= k) m -= k - 1;
for(int l = 1;l + len - 1 <= n;l ++)
{
int r = l + len - 1;
if(len == 1) f[l][l][s[l] - '0'] = 0;
else
{
if(m == 1)
{
for(int x = 0;x < (1 << k);x ++)
for(int mid = r - 1;mid >= l;mid -= (k - 1))
f[l][r][c[x]] = max(f[l][r][c[x]],f[l][mid][x >> 1] + f[mid + 1][r][x & 1] + w[x]);
}else
{
for(int x = 0;x < (1 << m);x ++)
for(int mid = r - 1;mid >= l;mid -= (k - 1))
f[l][r][x] = max(f[l][r][x],f[l][mid][x >> 1] + f[mid + 1][r][x & 1]);
}
}
}
}
LL res = -INF;
for(int i = 0;i < 1 << k;i ++)
res = max(res,f[1][n][i]);
printf("%lld\n",res);
}
int main()
{
#ifdef LOCAL
freopen("in.txt", "r", stdin);
freopen("out.txt", "w", stdout);
#else
#endif // LOCAL
int T = 1;
// init();
// scanf("%d",&T);
while(T --)
{
solve();
}
return 0;
} 
京公网安备 11010502036488号