题目链接
题目描述
给定一个包含 个正整数的数组和一个模数
。你需要从数组中选择一个子序列(可以是空序列),使得子序列中所有元素的和对
取模后的结果最大。
解题思路
-
问题分析与暴力解法
我们需要找到一个子序列,其和
sum
使得sum % m
最大。一个朴素的想法是枚举所有可能的子序列。一个长度为
的数组共有
个子序列。对于每个子序列,我们计算其和并取模,然后更新最大值。当
较大时(如此题中的
),
是一个天文数字,暴力枚举会严重超时。
-
折半搜索 (Meet-in-the-Middle)
当问题规模
在40左右,使得
无法接受,但
可以接受时,折半搜索是一种非常有效的优化策略。
核心思想是:将原问题分解为两个规模减半的独立子问题,分别求解,然后将两个子问题的解合并起来得到原问题的解。
-
算法步骤
-
分割数组:将原数组
分成两半:前半部分
a[0...n/2-1]
和后半部分a[n/2...n-1]
。 -
分别求解子问题:
- 对前半部分,通过 DFS 或二进制枚举,计算出所有可能的子序列和,并将这些和对
取模后的结果存入一个集合
sums1
。 - 同样地,对后半部分,计算出所有子序列和模
的结果,存入集合
sums2
。
- 对前半部分,通过 DFS 或二进制枚举,计算出所有可能的子序列和,并将这些和对
-
合并结果:
- 现在,原数组的任意一个子序列和
total_sum
,都可以看作是sums1
中的一个元素s1
与sums2
中的一个元素s2
的和,即total_sum = s1 + s2
。 - 我们的目标是最大化
(s1 + s2) % m
。 - 遍历
sums1
中的每一个元素s1
。对于每个s1
,我们希望在sums2
中找到一个s2
,使得(s1 + s2)
尽可能大但又不超过m-1
的某个倍数加上m-1
。 - 具体来说,对于一个
s1
,我们想找的s2
应该满足s1 + s2
接近m-1
或m-1 + m
等。- 情况一:
s1 + s2 < m
。为了让和最大,我们希望s2
尽可能大,但要满足s2 < m - s1
。即在sums2
中找到小于m - s1
的最大元素。 - 情况二:
s1 + s2 >= m
。这时(s1 + s2) % m = s1 + s2 - m
。为了让这个值最大,我们希望s1 + s2
最大,即s2
最大。
- 情况一:
- 将
sums2
排序后,我们可以通过二分查找来高效地完成上述寻找过程。 - 对于每个
s1
,我们需要寻找的目标值是target = m - 1 - s1
。- 在
sums2
中二分查找不大于target
的最大元素s2
。那么s1 + s2
是一个可能的答案。 - 同时,
sums1
和sums2
中的最大元素之和(s1 + s2_max) % m
也是一个可能的答案(对应情况二)。
- 在
- 在整个遍历过程中,维护一个全局最大值
max_mod_sum
。
- 现在,原数组的任意一个子序列和
-
-
优化合并
上面的合并逻辑可以简化:
- 将
sums2
排序。 - 遍历
sums1
中的每个元素s1
。 - 对于每个
s1
,我们想找一个s2 \in sums2
来最大化(s1 + s2) \% m
。 - 理想的目标和是
m-1
。因此,我们希望s2
能等于m - 1 - s1
。 - 在
sums2
中,二分查找target = m - 1 - s1
的上界 (upper_bound)。upper_bound
会返回第一个大于target
的元素的迭代器。- 如果我们将这个迭代器向前移动一位,就得到了
sums2
中小于或等于target
的最大元素。记为s2_candidate
。 - 那么
(s1 + s2_candidate) % m
就是一个潜在的最大值。
- 同时,别忘了
s1
和sums2
中最后一个(也就是最大)元素s2_last
的组合,即(s1 + s2_last) % m
。这也可能产生最大值(当s1+s2
跨越了的倍数时)。
- 在所有这些可能中取最大值即可。
- 将
代码
#include <iostream>
#include <vector>
#include <numeric>
#include <algorithm>
#include <set>
using namespace std;
void get_all_sums(int start, int end, const vector<long long>& a, int m, set<long long>& sums) {
int len = end - start;
for (int i = 0; i < (1 << len); ++i) {
long long current_sum = 0;
for (int j = 0; j < len; ++j) {
if ((i >> j) & 1) {
current_sum = (current_sum + a[start + j]) % m;
}
}
sums.insert(current_sum);
}
}
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
long long m;
cin >> n >> m;
vector<long long> a(n);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
set<long long> sums1, sums2_set;
int mid = n / 2;
get_all_sums(0, mid, a, m, sums1);
get_all_sums(mid, n, a, m, sums2_set);
vector<long long> sums2(sums2_set.begin(), sums2_set.end());
long long max_mod_sum = 0;
for (long long s1 : sums1) {
// 寻找 s2 < m - s1 的最大 s2
long long target = m - 1 - s1;
auto it = upper_bound(sums2.begin(), sums2.end(), target);
if (it != sums2.begin()) {
--it;
max_mod_sum = max(max_mod_sum, (s1 + *it) % m);
}
// 考虑 s1 + s2 >= m 的情况
// s1 + sums2中最大的元素
max_mod_sum = max(max_mod_sum, (s1 + sums2.back()) % m);
}
cout << max_mod_sum << endl;
return 0;
}
import java.util.Scanner;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
public class Main {
private static void getAllSums(int start, int end, long[] a, long m, Set<Long> sums) {
int len = end - start;
for (int i = 0; i < (1 << len); i++) {
long currentSum = 0;
for (int j = 0; j < len; j++) {
if (((i >> j) & 1) == 1) {
currentSum = (currentSum + a[start + j]) % m;
}
}
sums.add(currentSum);
}
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
long m = sc.nextLong();
long[] a = new long[n];
for (int i = 0; i < n; i++) {
a[i] = sc.nextLong();
}
Set<Long> sums1_set = new HashSet<>();
Set<Long> sums2_set = new HashSet<>();
int mid = n / 2;
getAllSums(0, mid, a, m, sums1_set);
getAllSums(mid, n, a, m, sums2_set);
List<Long> sums1 = new ArrayList<>(sums1_set);
List<Long> sums2 = new ArrayList<>(sums2_set);
Collections.sort(sums2);
long maxModSum = 0;
for (long s1 : sums1) {
long target = m - 1 - s1;
int idx = Collections.binarySearch(sums2, target);
if (idx < 0) {
// (-insertion point - 1), so insertion point is -idx-1
idx = -idx - 2;
}
if (idx >= 0) {
maxModSum = Math.max(maxModSum, (s1 + sums2.get(idx)) % m);
}
maxModSum = Math.max(maxModSum, (s1 + sums2.get(sums2.size() - 1)) % m);
}
System.out.println(maxModSum);
}
}
import sys
from bisect import bisect_right
def get_all_sums(arr, m):
sums = {0}
for x in arr:
new_sums = set()
for s in sums:
new_sums.add((s + x) % m)
sums.update(new_sums)
return list(sums)
def solve():
line1 = sys.stdin.readline()
if not line1:
return
n, m = map(int, line1.split())
a = list(map(int, sys.stdin.readline().split()))
mid = n // 2
sums1 = get_all_sums(a[:mid], m)
sums2 = get_all_sums(a[mid:], m)
sums2.sort()
max_mod_sum = 0
for s1 in sums1:
target = m - 1 - s1
# Find index of the first element > target
idx = bisect_right(sums2, target)
if idx > 0:
# The element at idx-1 is <= target
s2_candidate = sums2[idx - 1]
max_mod_sum = max(max_mod_sum, (s1 + s2_candidate) % m)
# Also check with the largest element in sums2
if sums2: # ensure sums2 is not empty
max_mod_sum = max(max_mod_sum, (s1 + sums2[-1]) % m)
print(max_mod_sum)
solve()
算法及复杂度
- 算法:折半搜索 (Meet-in-the-Middle)
- 时间复杂度:
。生成两半子序列和的集合分别需要
。将其中一个集合排序需要
。最后,遍历第一个集合并对第二个集合进行二分查找,需要
。
- 空间复杂度:
,用于存储两半子序列和的集合。