题目链接

REAL739 小红的数组切割

题目描述

小红有一个长度为 的数组 和一个长度为 的字符串 。她最多可以将数组切割成 个连续的块。

数组的总权值为所有元素的权值之和。对于数组中的第 个元素,其权值计算方式为:。其中:

  • 的值取决于字符串 的第 个字符:若 ,则 ;若 ,则
  • 表示 所在的块的编号(从1开始)。

小红想通过合理的切割方式,使得数组的总权值最大。请你帮她计算出可能的最大权值。

思路分析

1. 拆分权值公式

总权值 ,其中 是元素 所在块的编号。 我们可以将公式展开:

  • 固定部分: 第一项 的值仅由输入的 决定,与如何切割无关。我们可以预先计算这部分的值,称之为 base_value
  • 可变部分: 第二项 的值取决于切割方案。我们的目标就是最大化这一部分,称之为 partition_value

2. 贪心策略

我们来分析“切割”操作对 partition_value 的影响。

  • 基准情况 (1个块): 所有元素都在第1块,此时 对所有 成立。partition_value
  • 增加一个切点: 假设我们在索引 c 之后增加一个切点(即在 之间切开)。这使得从 的所有元素的块编号都增加了1。这会给 partition_value 带来的增益

这个增益值可以通过 op 数组的后缀和 suffix_op 快速计算得到。在索引 c 之后切割的增益就是 suffix_op[c+1]

我们最多可以进行 次切割(以形成最多 个块)。为了最大化总权值,我们应该贪心地选择那些能带来正增益的切割点。

3. 算法步骤

  1. 根据字符串 生成 op 数组(1 或 -1)。
  2. 计算固定的 base_value = sum(op[i] * a[i])
  3. 计算 op 数组的后缀和 suffix_opsuffix_op[i] 表示
  4. partition_value 的基准值是 suffix_op[0] (对应1个块的情况)。
  5. 收集所有可能的切割增益:gains 列表包含 suffix_op[1], suffix_op[2], ..., suffix_op[n-1] 中所有大于0的项。
  6. gains 列表降序排序。
  7. 我们最多能进行 次切割,所以我们选取 gains 中前 min(k-1, len(gains)) 个最大的正增益,累加到 partition_value 上。
  8. 最终答案是 base_value + partition_value

所有求和过程都可能超过 int 范围,需要使用 long long

代码

#include <iostream>
#include <vector>
#include <string>
#include <numeric>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n, k;
    cin >> n >> k;

    vector<long long> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    string s;
    cin >> s;

    vector<int> op(n);
    long long base_value = 0;
    for (int i = 0; i < n; ++i) {
        op[i] = (s[i] == '1' ? 1 : -1);
        base_value += op[i] * a[i];
    }

    vector<long long> suffix_op(n + 1, 0);
    for (int i = n - 1; i >= 0; --i) {
        suffix_op[i] = suffix_op[i+1] + op[i];
    }

    vector<long long> gains;
    // 增益来自在索引c(0 to n-2)后切割,值为suffix_op[c+1]
    // 对应于suffix_op数组的索引1到n-1
    for (int i = 1; i < n; ++i) {
        if (suffix_op[i] > 0) {
            gains.push_back(suffix_op[i]);
        }
    }

    sort(gains.rbegin(), gains.rend());

    long long partition_value = suffix_op[0];
    int cuts_to_make = min((int)gains.size(), k - 1);
    for (int i = 0; i < cuts_to_make; ++i) {
        partition_value += gains[i];
    }

    cout << base_value + partition_value << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int k = sc.nextInt();

        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextLong();
        }
        String s = sc.next();

        int[] op = new int[n];
        long baseValue = 0;
        for (int i = 0; i < n; i++) {
            op[i] = (s.charAt(i) == '1' ? 1 : -1);
            baseValue += op[i] * a[i];
        }

        long[] suffixOp = new long[n + 1];
        for (int i = n - 1; i >= 0; i--) {
            suffixOp[i] = suffixOp[i + 1] + op[i];
        }

        List<Long> gains = new ArrayList<>();
        for (int i = 1; i < n; i++) {
            if (suffixOp[i] > 0) {
                gains.add(suffixOp[i]);
            }
        }

        gains.sort(Collections.reverseOrder());

        long partitionValue = suffixOp[0];
        int cutsToMake = Math.min(gains.size(), k - 1);
        for (int i = 0; i < cutsToMake; i++) {
            partitionValue += gains.get(i);
        }

        System.out.println(baseValue + partitionValue);
    }
}
import sys

def solve():
    n, k = map(int, sys.stdin.readline().split())
    a = list(map(int, sys.stdin.readline().split()))
    s = sys.stdin.readline().strip()

    op = [1 if char == '1' else -1 for char in s]
    
    base_value = sum(op[i] * a[i] for i in range(n))

    suffix_op = [0] * (n + 1)
    for i in range(n - 1, -1, -1):
        suffix_op[i] = suffix_op[i + 1] + op[i]

    gains = []
    # 增益来自在索引c(0 to n-2)后切割,值为suffix_op[c+1]
    # 对应于suffix_op数组的索引1到n-1
    for i in range(1, n):
        if suffix_op[i] > 0:
            gains.append(suffix_op[i])
            
    gains.sort(reverse=True)

    partition_value = suffix_op[0]
    cuts_to_make = min(len(gains), k - 1)
    
    for i in range(cuts_to_make):
        partition_value += gains[i]
        
    print(base_value + partition_value)

solve()

算法及复杂度

  • 算法:贪心 + 前缀和/后缀和
  • 时间复杂度,其中 是数组的长度。计算 base_value 和后缀和都是 。主要的时间开销在于对增益列表 gains 的排序,其长度最多为
  • 空间复杂度,用于存储数组 aopsuffix_opgains