题目链接

四值零和

题目描述

给定一个 的整数矩阵。要求从四列中各选一个数,使得这四个数的和为 0。请计算共有多少种不同的选法组合。 形式化地,给定四个长度为 的数组 ,找出满足 的四元组 的数量,其中 的范围均为

解题思路

一个朴素的解法是使用四重循环遍历所有可能的下标组合 ,检查它们的和是否为 0。这种方法的时间复杂度为 ,会严重超时。

为了优化算法,我们可以对原方程 进行变形,得到

这个形式启发我们使用折半枚举(Meet-in-the-Middle) 的思想。之前我们尝试了“排序+二分查找”(超时)和“哈希表”(内存超限)的方法。为了在时间和空间上取得平衡,我们采用第三种经典方法:折半枚举 + 双指针

  1. 预处理:我们首先用两重循环计算出两组和:
    • sums_ab:包含所有 的和。
    • sums_cd:包含所有 的和。
  2. 排序:我们将 sums_ab升序排序,并将 sums_cd 也按升序排序。
  3. 双指针扫描
    • 我们设置两个指针:ptr1 指向 sums_ab 的开头(最小值),ptr2 指向 sums_cd末尾(最大值)。
    • 我们移动指针,计算 current_sum = sums_ab[ptr1] + sums_cd[ptr2]
      • current_sum == 0:我们找到了一个解。由于数组中可能存在重复元素,我们需要统计当前 sums_ab[ptr1]sums_cd[ptr2] 各自连续出现的次数(count1count2)。然后将 count1 * count2 累加到总答案中,并把两个指针分别移动到下一个不重复的元素上。
      • current_sum < 0:和太小,需要增大 sums_ab 的值,因此 ptr1 右移。
      • current_sum > 0:和太大,需要减小 sums_cd 的值,因此 ptr2 左移。
    • 循环直到 ptr1 越界或 ptr2 越界。

该方法的时间复杂度瓶颈在于排序,为 ,后续的双指针扫描为线性的 ,比 次二分查找快得多。空间复杂度为 ,用于存储两组和。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;
    vector<long long> a(n), b(n), c(n), d(n);
    for (int i = 0; i < n; ++i) {
        cin >> a[i] >> b[i] >> c[i] >> d[i];
    }

    vector<long long> sums_ab;
    sums_ab.reserve(n * n);
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            sums_ab.push_back(a[i] + b[j]);
        }
    }

    sort(sums_ab.begin(), sums_ab.end());

    long long count = 0;
    for (int i = 0; i < n; ++i) {
        for (int j = 0; j < n; ++j) {
            long long target = -(c[i] + d[j]);
            auto range = equal_range(sums_ab.begin(), sums_ab.end(), target);
            count += distance(range.first, range.second);
        }
    }

    cout << count << endl;

    return 0;
}
import java.io.BufferedReader;
import java.io.InputStreamReader;
import java.io.IOException;
import java.util.StringTokenizer;
import java.util.Arrays;

public class Main {
    // 找到第一个 >= target 的索引
    private static int lowerBound(long[] arr, long target) {
        int left = 0, right = arr.length;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (arr[mid] >= target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
    
    // 找到第一个 > target 的索引
    private static int upperBound(long[] arr, long target) {
        int left = 0, right = arr.length;
        while (left < right) {
            int mid = left + (right - left) / 2;
            if (arr[mid] > target) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }

    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine());
        long[] a = new long[n];
        long[] b = new long[n];
        long[] c = new long[n];
        long[] d = new long[n];

        for (int i = 0; i < n; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            a[i] = Long.parseLong(st.nextToken());
            b[i] = Long.parseLong(st.nextToken());
            c[i] = Long.parseLong(st.nextToken());
            d[i] = Long.parseLong(st.nextToken());
        }

        long[] sums_ab = new long[n * n];
        int idx = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                sums_ab[idx++] = a[i] + b[j];
            }
        }

        Arrays.sort(sums_ab);

        long count = 0;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                long target = -(c[i] + d[j]);
                int upper = upperBound(sums_ab, target);
                int lower = lowerBound(sums_ab, target);
                count += (upper - lower);
            }
        }

        System.out.println(count);
    }
}
import sys

def main():
    try:
        n_str = sys.stdin.readline()
        if not n_str:
            return
        n = int(n_str)
        a, b, c, d = [0]*n, [0]*n, [0]*n, [0]*n
        for i in range(n):
            a[i], b[i], c[i], d[i] = map(int, sys.stdin.readline().split())
    except (IOError, ValueError):
        return

    sums_ab = []
    for val_a in a:
        for val_b in b:
            sums_ab.append(val_a + val_b)
    
    sums_cd = []
    for val_c in c:
        for val_d in d:
            sums_cd.append(val_c + val_d)

    sums_ab.sort()
    sums_cd.sort()

    count = 0
    ptr1, ptr2 = 0, len(sums_cd) - 1
    len_ab = len(sums_ab)

    while ptr1 < len_ab and ptr2 >= 0:
        current_sum = sums_ab[ptr1] + sums_cd[ptr2]
        if current_sum == 0:
            val1 = sums_ab[ptr1]
            count1 = 0
            while ptr1 < len_ab and sums_ab[ptr1] == val1:
                count1 += 1
                ptr1 += 1
            
            val2 = sums_cd[ptr2]
            count2 = 0
            while ptr2 >= 0 and sums_cd[ptr2] == val2:
                count2 += 1
                ptr2 -= 1
            
            count += count1 * count2
        elif current_sum < 0:
            ptr1 += 1
        else:
            ptr2 -= 1
            
    print(count)

if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:折半枚举(Meet-in-the-Middle)+ 双指针。
  • 时间复杂度。生成两组和需要 ,对它们排序需要 。最后的双指针扫描是线性的,需要
  • 空间复杂度,用于存储两组 的和。