题目链接
题目描述
给定一个 的整数矩阵。要求从四列中各选一个数,使得这四个数的和为 0。请计算共有多少种不同的选法组合。
形式化地,给定四个长度为
的数组
,找出满足
的四元组
的数量,其中
的范围均为
。
解题思路
一个朴素的解法是使用四重循环遍历所有可能的下标组合 ,检查它们的和是否为 0。这种方法的时间复杂度为
,会严重超时。
为了优化算法,我们可以对原方程 进行变形,得到
。
这个形式启发我们使用折半枚举(Meet-in-the-Middle) 的思想。之前我们尝试了“排序+二分查找”(超时)和“哈希表”(内存超限)的方法。为了在时间和空间上取得平衡,我们采用第三种经典方法:折半枚举 + 双指针。
- 预处理:我们首先用两重循环计算出两组和:
sums_ab
:包含所有个
的和。
sums_cd
:包含所有个
的和。
- 排序:我们将
sums_ab
按升序排序,并将sums_cd
也按升序排序。 - 双指针扫描:
- 我们设置两个指针:
ptr1
指向sums_ab
的开头(最小值),ptr2
指向sums_cd
的末尾(最大值)。 - 我们移动指针,计算
current_sum = sums_ab[ptr1] + sums_cd[ptr2]
:- 若
current_sum == 0
:我们找到了一个解。由于数组中可能存在重复元素,我们需要统计当前sums_ab[ptr1]
和sums_cd[ptr2]
各自连续出现的次数(count1
和count2
)。然后将count1 * count2
累加到总答案中,并把两个指针分别移动到下一个不重复的元素上。 - 若
current_sum < 0
:和太小,需要增大sums_ab
的值,因此ptr1
右移。 - 若
current_sum > 0
:和太大,需要减小sums_cd
的值,因此ptr2
左移。
- 若
- 循环直到
ptr1
越界或ptr2
越界。
- 我们设置两个指针:
该方法的时间复杂度瓶颈在于排序,为 ,后续的双指针扫描为线性的
,比
次二分查找快得多。空间复杂度为
,用于存储两组和。
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<long long> a(n), b(n), c(n), d(n);
for (int i = 0; i < n; ++i) {
cin >> a[i] >> b[i] >> c[i] >> d[i];
}
vector<long long> sums_ab;
sums_ab.reserve(n * n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
sums_ab.push_back(a[i] + b[j]);
}
}
sort(sums_ab.begin(), sums_ab.end());
long long count = 0;
for (int i = 0; i < n; ++i) {
for (int j = 0; j < n; ++j) {
long long target = -(c[i] + d[j]);
auto range = equal_range(sums_ab.begin(), sums_ab.end(), target);
count += distance(range.first, range.second);
}
}
cout << count << endl;
return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.StringTokenizer;
import java.util.Arrays;
public class Main {
// 找到第一个 >= target 的索引
private static int lowerBound(long[] arr, long target) {
int left = 0, right = arr.length;
while (left < right) {
int mid = left + (right - left) / 2;
if (arr[mid] >= target) {
right = mid;
} else {
left = mid + 1;
}
}
return left;
}
// 找到第一个 > target 的索引
private static int upperBound(long[] arr, long target) {
int left = 0, right = arr.length;
while (left < right) {
int mid = left + (right - left) / 2;
if (arr[mid] > target) {
right = mid;
} else {
left = mid + 1;
}
}
return left;
}
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine());
long[] a = new long[n];
long[] b = new long[n];
long[] c = new long[n];
long[] d = new long[n];
for (int i = 0; i < n; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
a[i] = Long.parseLong(st.nextToken());
b[i] = Long.parseLong(st.nextToken());
c[i] = Long.parseLong(st.nextToken());
d[i] = Long.parseLong(st.nextToken());
}
long[] sums_ab = new long[n * n];
int idx = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
sums_ab[idx++] = a[i] + b[j];
}
}
Arrays.sort(sums_ab);
long count = 0;
for (int i = 0; i < n; i++) {
for (int j = 0; j < n; j++) {
long target = -(c[i] + d[j]);
int upper = upperBound(sums_ab, target);
int lower = lowerBound(sums_ab, target);
count += (upper - lower);
}
}
System.out.println(count);
}
}
import sys
def main():
try:
n_str = sys.stdin.readline()
if not n_str:
return
n = int(n_str)
a, b, c, d = [0]*n, [0]*n, [0]*n, [0]*n
for i in range(n):
a[i], b[i], c[i], d[i] = map(int, sys.stdin.readline().split())
except (IOError, ValueError):
return
sums_ab = []
for val_a in a:
for val_b in b:
sums_ab.append(val_a + val_b)
sums_cd = []
for val_c in c:
for val_d in d:
sums_cd.append(val_c + val_d)
sums_ab.sort()
sums_cd.sort()
count = 0
ptr1, ptr2 = 0, len(sums_cd) - 1
len_ab = len(sums_ab)
while ptr1 < len_ab and ptr2 >= 0:
current_sum = sums_ab[ptr1] + sums_cd[ptr2]
if current_sum == 0:
val1 = sums_ab[ptr1]
count1 = 0
while ptr1 < len_ab and sums_ab[ptr1] == val1:
count1 += 1
ptr1 += 1
val2 = sums_cd[ptr2]
count2 = 0
while ptr2 >= 0 and sums_cd[ptr2] == val2:
count2 += 1
ptr2 -= 1
count += count1 * count2
elif current_sum < 0:
ptr1 += 1
else:
ptr2 -= 1
print(count)
if __name__ == "__main__":
main()
算法及复杂度
- 算法:折半枚举(Meet-in-the-Middle)+ 双指针。
- 时间复杂度:
。生成两组和需要
,对它们排序需要
。最后的双指针扫描是线性的,需要
。
- 空间复杂度:
,用于存储两组
的和。