题目链接

模意义下最大子序列和(Hard Version)

题目描述

给定一个包含 个正整数的数组和一个模数 。你需要从数组中选择一个子序列(可以是空序列),使得子序列中所有元素的和对 取模后的结果最大。

解题思路

  1. 问题分析与暴力解法

    我们需要找到一个子序列,其和 sum 使得 sum % m 最大。

    一个朴素的想法是枚举所有可能的子序列。一个长度为 的数组共有 个子序列。对于每个子序列,我们计算其和并取模,然后更新最大值。当 较大时(如此题中的 ), 是一个天文数字,暴力枚举会严重超时。

  2. 折半搜索 (Meet-in-the-Middle)

    当问题规模 在40左右,使得 无法接受,但 可以接受时,折半搜索是一种非常有效的优化策略。

    核心思想是:将原问题分解为两个规模减半的独立子问题,分别求解,然后将两个子问题的解合并起来得到原问题的解。

  3. 算法步骤

    1. 分割数组:将原数组 分成两半:前半部分 a[0...n/2-1] 和后半部分 a[n/2...n-1]

    2. 分别求解子问题

      • 对前半部分,通过 DFS 或二进制枚举,计算出所有可能的子序列和,并将这些和对 取模后的结果存入一个集合 sums1
      • 同样地,对后半部分,计算出所有子序列和模 的结果,存入集合 sums2
    3. 合并结果

      • 现在,原数组的任意一个子序列和 total_sum,都可以看作是 sums1 中的一个元素 s1sums2 中的一个元素 s2 的和,即 total_sum = s1 + s2
      • 我们的目标是最大化 (s1 + s2) % m
      • 遍历 sums1 中的每一个元素 s1。对于每个 s1,我们希望在 sums2 中找到一个 s2,使得 (s1 + s2) 尽可能大但又不超过 m-1 的某个倍数加上 m-1
      • 具体来说,对于一个 s1,我们想找的 s2 应该满足 s1 + s2 接近 m-1m-1 + m 等。
        • 情况一s1 + s2 < m。为了让和最大,我们希望 s2 尽可能大,但要满足 s2 < m - s1。即在 sums2 中找到小于 m - s1 的最大元素。
        • 情况二s1 + s2 >= m。这时 (s1 + s2) % m = s1 + s2 - m。为了让这个值最大,我们希望 s1 + s2 最大,即 s2 最大。
      • sums2 排序后,我们可以通过二分查找来高效地完成上述寻找过程。
      • 对于每个 s1,我们需要寻找的目标值是 target = m - 1 - s1
        • sums2 中二分查找不大于 target 的最大元素 s2。那么 s1 + s2 是一个可能的答案。
        • 同时,sums1sums2 中的最大元素之和 (s1 + s2_max) % m 也是一个可能的答案(对应情况二)。
      • 在整个遍历过程中,维护一个全局最大值 max_mod_sum
  4. 优化合并

    上面的合并逻辑可以简化:

    1. sums2 排序。
    2. 遍历 sums1 中的每个元素 s1
    3. 对于每个 s1,我们想找一个 s2 \in sums2 来最大化 (s1 + s2) \% m
    4. 理想的目标和是 m-1。因此,我们希望 s2 能等于 m - 1 - s1
    5. sums2 中,二分查找 target = m - 1 - s1上界 (upper_bound)
      • upper_bound 会返回第一个大于 target 的元素的迭代器。
      • 如果我们将这个迭代器向前移动一位,就得到了 sums2小于或等于 target 的最大元素。记为 s2_candidate
      • 那么 (s1 + s2_candidate) % m 就是一个潜在的最大值。
    6. 同时,别忘了 s1sums2 中最后一个(也就是最大)元素 s2_last 的组合,即 (s1 + s2_last) % m。这也可能产生最大值(当 s1+s2 跨越了 的倍数时)。
    7. 在所有这些可能中取最大值即可。

代码

#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <set>

using namespace std;

void get_all_sums(int start, int end, const vector<long long>& a, int m, set<long long>& sums) {
    int len = end - start;
    for (int i = 0; i < (1 << len); ++i) {
        long long current_sum = 0;
        for (int j = 0; j < len; ++j) {
            if ((i >> j) & 1) {
                current_sum = (current_sum + a[start + j]) % m;
            }
        }
        sums.insert(current_sum);
    }
}

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    long long m;
    cin >> n >> m;

    vector<long long> a(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
    }

    set<long long> sums1, sums2_set;
    int mid = n / 2;
    get_all_sums(0, mid, a, m, sums1);
    get_all_sums(mid, n, a, m, sums2_set);
    
    vector<long long> sums2(sums2_set.begin(), sums2_set.end());

    long long max_mod_sum = 0;

    for (long long s1 : sums1) {
        // 寻找 s2 < m - s1 的最大 s2
        long long target = m - 1 - s1;
        auto it = upper_bound(sums2.begin(), sums2.end(), target);
        
        if (it != sums2.begin()) {
            --it;
            max_mod_sum = max(max_mod_sum, (s1 + *it) % m);
        }
        
        // 考虑 s1 + s2 >= m 的情况
        // s1 + sums2中最大的元素
        max_mod_sum = max(max_mod_sum, (s1 + sums2.back()) % m);
    }

    cout << max_mod_sum << endl;

    return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

public class Main {
    private static void getAllSums(int start, int end, long[] a, long m, Set<Long> sums) {
        int len = end - start;
        for (int i = 0; i < (1 << len); i++) {
            long currentSum = 0;
            for (int j = 0; j < len; j++) {
                if (((i >> j) & 1) == 1) {
                    currentSum = (currentSum + a[start + j]) % m;
                }
            }
            sums.add(currentSum);
        }
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        long m = sc.nextLong();
        long[] a = new long[n];
        for (int i = 0; i < n; i++) {
            a[i] = sc.nextLong();
        }

        Set<Long> sums1_set = new HashSet<>();
        Set<Long> sums2_set = new HashSet<>();
        int mid = n / 2;
        getAllSums(0, mid, a, m, sums1_set);
        getAllSums(mid, n, a, m, sums2_set);
        
        List<Long> sums1 = new ArrayList<>(sums1_set);
        List<Long> sums2 = new ArrayList<>(sums2_set);
        Collections.sort(sums2);
        
        long maxModSum = 0;

        for (long s1 : sums1) {
            long target = m - 1 - s1;
            
            int idx = Collections.binarySearch(sums2, target);
            if (idx < 0) {
                // (-insertion point - 1), so insertion point is -idx-1
                idx = -idx - 2;
            }
            
            if (idx >= 0) {
                maxModSum = Math.max(maxModSum, (s1 + sums2.get(idx)) % m);
            }

            maxModSum = Math.max(maxModSum, (s1 + sums2.get(sums2.size() - 1)) % m);
        }

        System.out.println(maxModSum);
    }
}
import sys
from bisect import bisect_right

def get_all_sums(arr, m):
    sums = {0}
    for x in arr:
        new_sums = set()
        for s in sums:
            new_sums.add((s + x) % m)
        sums.update(new_sums)
    return list(sums)

def solve():
    line1 = sys.stdin.readline()
    if not line1:
        return
    n, m = map(int, line1.split())
    a = list(map(int, sys.stdin.readline().split()))

    mid = n // 2
    sums1 = get_all_sums(a[:mid], m)
    sums2 = get_all_sums(a[mid:], m)
    sums2.sort()
    
    max_mod_sum = 0
    
    for s1 in sums1:
        target = m - 1 - s1
        
        # Find index of the first element > target
        idx = bisect_right(sums2, target)
        
        if idx > 0:
            # The element at idx-1 is <= target
            s2_candidate = sums2[idx - 1]
            max_mod_sum = max(max_mod_sum, (s1 + s2_candidate) % m)
            
        # Also check with the largest element in sums2
        if sums2: # ensure sums2 is not empty
             max_mod_sum = max(max_mod_sum, (s1 + sums2[-1]) % m)

    print(max_mod_sum)

solve()

算法及复杂度

  • 算法:折半搜索 (Meet-in-the-Middle)
  • 时间复杂度。生成两半子序列和的集合分别需要 。将其中一个集合排序需要 。最后,遍历第一个集合并对第二个集合进行二分查找,需要
  • 空间复杂度,用于存储两半子序列和的集合。