题目链接
题目描述
小红有一个长度为 的数组
和一个长度为
的字符串
。她最多可以将数组切割成
个连续的块。
数组的总权值为所有元素的权值之和。对于数组中的第 个元素,其权值计算方式为:
。其中:
的值取决于字符串
的第
个字符:若
,则
;若
,则
。
表示
所在的块的编号(从1开始)。
小红想通过合理的切割方式,使得数组的总权值最大。请你帮她计算出可能的最大权值。
思路分析
1. 拆分权值公式
总权值 ,其中
是元素
所在块的编号。
我们可以将公式展开:
- 固定部分: 第一项
的值仅由输入的
和
决定,与如何切割无关。我们可以预先计算这部分的值,称之为
base_value
。 - 可变部分: 第二项
的值取决于切割方案。我们的目标就是最大化这一部分,称之为
partition_value
。
2. 贪心策略
我们来分析“切割”操作对 partition_value
的影响。
- 基准情况 (1个块): 所有元素都在第1块,此时
对所有
成立。
partition_value
为。
- 增加一个切点: 假设我们在索引
c
之后增加一个切点(即在和
之间切开)。这使得从
到
的所有元素的块编号都增加了1。这会给
partition_value
带来的增益为。
这个增益值可以通过 op
数组的后缀和 suffix_op
快速计算得到。在索引 c
之后切割的增益就是 suffix_op[c+1]
。
我们最多可以进行 次切割(以形成最多
个块)。为了最大化总权值,我们应该贪心地选择那些能带来正增益的切割点。
3. 算法步骤
- 根据字符串
生成
op
数组(1 或 -1)。 - 计算固定的
base_value = sum(op[i] * a[i])
。 - 计算
op
数组的后缀和suffix_op
。suffix_op[i]
表示。
partition_value
的基准值是suffix_op[0]
(对应1个块的情况)。- 收集所有可能的切割增益:
gains
列表包含suffix_op[1], suffix_op[2], ..., suffix_op[n-1]
中所有大于0的项。 - 将
gains
列表降序排序。 - 我们最多能进行
次切割,所以我们选取
gains
中前min(k-1, len(gains))
个最大的正增益,累加到partition_value
上。 - 最终答案是
base_value + partition_value
。
所有求和过程都可能超过 int
范围,需要使用 long long
。
代码
#include <iostream>
#include <vector>
#include <string>
#include <numeric>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n, k;
cin >> n >> k;
vector<long long> a(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
string s;
cin >> s;
vector<int> op(n);
long long base_value = 0;
for (int i = 0; i < n; ++i) {
op[i] = (s[i] == '1' ? 1 : -1);
base_value += op[i] * a[i];
}
vector<long long> suffix_op(n + 1, 0);
for (int i = n - 1; i >= 0; --i) {
suffix_op[i] = suffix_op[i+1] + op[i];
}
vector<long long> gains;
// 增益来自在索引c(0 to n-2)后切割,值为suffix_op[c+1]
// 对应于suffix_op数组的索引1到n-1
for (int i = 1; i < n; ++i) {
if (suffix_op[i] > 0) {
gains.push_back(suffix_op[i]);
}
}
sort(gains.rbegin(), gains.rend());
long long partition_value = suffix_op[0];
int cuts_to_make = min((int)gains.size(), k - 1);
for (int i = 0; i < cuts_to_make; ++i) {
partition_value += gains[i];
}
cout << base_value + partition_value << endl;
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int k = sc.nextInt();
long[] a = new long[n];
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
}
String s = sc.next();
int[] op = new int[n];
long baseValue = 0;
for (int i = 0; i < n; i++) {
op[i] = (s.charAt(i) == '1' ? 1 : -1);
baseValue += op[i] * a[i];
}
long[] suffixOp = new long[n + 1];
for (int i = n - 1; i >= 0; i--) {
suffixOp[i] = suffixOp[i + 1] + op[i];
}
List<Long> gains = new ArrayList<>();
for (int i = 1; i < n; i++) {
if (suffixOp[i] > 0) {
gains.add(suffixOp[i]);
}
}
gains.sort(Collections.reverseOrder());
long partitionValue = suffixOp[0];
int cutsToMake = Math.min(gains.size(), k - 1);
for (int i = 0; i < cutsToMake; i++) {
partitionValue += gains.get(i);
}
System.out.println(baseValue + partitionValue);
}
}
import sys
def solve():
n, k = map(int, sys.stdin.readline().split())
a = list(map(int, sys.stdin.readline().split()))
s = sys.stdin.readline().strip()
op = [1 if char == '1' else -1 for char in s]
base_value = sum(op[i] * a[i] for i in range(n))
suffix_op = [0] * (n + 1)
for i in range(n - 1, -1, -1):
suffix_op[i] = suffix_op[i + 1] + op[i]
gains = []
# 增益来自在索引c(0 to n-2)后切割,值为suffix_op[c+1]
# 对应于suffix_op数组的索引1到n-1
for i in range(1, n):
if suffix_op[i] > 0:
gains.append(suffix_op[i])
gains.sort(reverse=True)
partition_value = suffix_op[0]
cuts_to_make = min(len(gains), k - 1)
for i in range(cuts_to_make):
partition_value += gains[i]
print(base_value + partition_value)
solve()
算法及复杂度
- 算法:贪心 + 前缀和/后缀和
- 时间复杂度:
,其中
是数组的长度。计算
base_value
和后缀和都是。主要的时间开销在于对增益列表
gains
的排序,其长度最多为。
- 空间复杂度:
,用于存储数组
a
、op
、suffix_op
和gains
。