题目链接

最佳配对

题目描述

给定两个长度为 N 的整型数组 AB。如果 A[i] == B[j],则 (i, j) 可构成一个配对。

一个“最佳配对集合”是指一系列配对,其中 AB 的每个索引最多只出现一次。

你的任务是修改 B 中的一个元素,使得“最佳配对集合”的规模(即配对数量)最大。输出这个最大数量。

解题思路

这个问题的核心在于分析单次修改操作对总配对数的影响。我们可以通过频率统计和贪心分析,在接近线性的时间内找到最优解。

1. 计算初始配对数

首先,如果不做任何修改,能形成的最大配对数是多少?

对于任意一个数值 v,它在 A 中出现了 countA[v] 次,在 B 中出现了 countB[v] 次。由于每个索引只能用一次,数值 v 最多能贡献 min(countA[v], countB[v]) 个配对。

因此,初始的最大配对数 M 等于所有数值 vmin(countA[v], countB[v]) 之和。这可以通过哈希表在 时间内计算出来。

2. 分析修改操作的收益

我们必须修改 B 中的一个元素,比如将 B[j] 的值从 v_old 改为 v_new。这一操作对总配对数的影响可以分解为“损失”和“增益”两部分。

  • 损失 (Loss)Bv_old 的数量减一。这是否会减少总配对数?

    • 如果原本 Bv_old 的数量就超过A 中的数量(即 countB[v_old] > countA[v_old]),说明 B 中有“多余的” v_old,减少一个并不会影响 min(countA, countB) 的值。没有损失
    • 如果原本 Bv_old 的数量不多于 A 中的数量(即 countB[v_old] <= countA[v_old]),那么减少一个就会使 min 的值减一。损失1个配对
    • 我们可以定义一个 loss 变量,当 countA[v_old] >= countB[v_old] 时为1,否则为0。
  • 增益 (Gain)Bv_new 的数量加一。这是否会增加总配对数?

    • 如果原本 Av_new 的数量多于 B 中的数量(即 countA[v_new] > countB[v_new]),说明 A 中有“待匹配的” v_new,此时 B 中增加一个正好可以与之配对。增加1个配对
    • 如果原本 Av_new 的数量不多于 B 中的数量(即 countA[v_new] <= countB[v_new]),B 再增加一个 v_new 也找不到 A 中新的元素来匹配。没有增益
    • 我们可以定义一个 gain 变量,当 countA[v_new] > countB[v_new] 时为1,否则为0。

3. 最优策略

修改后的总配对数 M' = M - loss + gain。为了最大化 M',我们需要最大化 gain - loss 的值。

gainloss 的取值都是 0 或 1。gain - loss 的可能取值为 1, 0, -1。我们可以分别寻找最优的 gain 和最优的 loss

  1. 最大化增益 (max_gain):我们希望 gain 为 1。这需要我们选择的新值 v_new 满足 countA[v_new] > countB[v_new]。我们只需检查是否存在任何一个这样的 v_new 即可。如果存在,max_gain = 1;否则 max_gain = 0

  2. 最小化损失 (min_loss):我们希望 loss 为 0。这需要我们选择被修改的 B[j](其值为 v_old)满足 countB[v_old] > countA[v_old]。我们只需检查 B 数组中是否存在任何一个这样的元素即可。如果存在,min_loss = 0;否则,无论修改哪个元素都会造成损失,min_loss = 1

最终,能达到的最大配对数就是 M + max_gain - min_loss

算法步骤总结

  1. 使用哈希表统计 AB 中每个数字的出现次数,得到 countAcountB
  2. 计算初始配对数 M = sum(min(countA[v], countB[v]))
  3. 初始化 max_gain = 0。遍历 countA,如果发现任何 v 使得 countA[v] > countB.get(v, 0),则将 max_gain 设为 1 并终止遍历。
  4. 初始化 min_loss = 1。遍历数组 B,如果发现任何元素 b 使得 countB[b] > countA.get(b, 0),则将 min_loss 设为 0 并终止遍历。
  5. 最终答案为 M + max_gain - min_loss

代码

#include <iostream>
#include <vector>
#include <unordered_map>
#include <algorithm>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<int> a(n), b(n);
    unordered_map<int, int> count_a, count_b;

    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        count_a[a[i]]++;
    }
    for (int i = 0; i < n; ++i) {
        cin >> b[i];
        count_b[b[i]]++;
    }

    long long initial_pairs = 0;
    for (auto const& [val, count] : count_a) {
        if (count_b.count(val)) {
            initial_pairs += min(count, count_b[val]);
        }
    }

    int max_gain = 0;
    for (auto const& [val, count] : count_a) {
        int b_count = count_b.count(val) ? count_b[val] : 0;
        if (count > b_count) {
            max_gain = 1;
            break;
        }
    }

    int min_loss = 1;
    for (int val : b) {
        int a_count = count_a.count(val) ? count_a[val] : 0;
        if (count_b[val] > a_count) {
            min_loss = 0;
            break;
        }
    }

    cout << initial_pairs + max_gain - min_loss << endl;

    return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();

        int[] a = new int[n];
        int[] b = new int[n];
        Map<Integer, Integer> countA = new HashMap<>();
        Map<Integer, Integer> countB = new HashMap<>();

        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
            countA.put(a[i], countA.getOrDefault(a[i], 0) + 1);
        }
        for (int i = 0; i < n; i++) {
            b[i] = sc.nextInt();
            countB.put(b[i], countB.getOrDefault(b[i], 0) + 1);
        }

        long initialPairs = 0;
        for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
            int val = entry.getKey();
            int count = entry.getValue();
            initialPairs += Math.min(count, countB.getOrDefault(val, 0));
        }

        int maxGain = 0;
        for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
            int val = entry.getKey();
            int count = entry.getValue();
            if (count > countB.getOrDefault(val, 0)) {
                maxGain = 1;
                break;
            }
        }

        int minLoss = 1;
        for (int val : b) {
            if (countB.get(val) > countA.getOrDefault(val, 0)) {
                minLoss = 0;
                break;
            }
        }

        System.out.println(initialPairs + maxGain - minLoss);
    }
}
import sys
from collections import Counter

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str:
            return
        n = int(n_str)
        
        a = list(map(int, sys.stdin.readline().split()))
        b = list(map(int, sys.stdin.readline().split()))
        
        count_a = Counter(a)
        count_b = Counter(b)
        
        initial_pairs = 0
        for val, count in count_a.items():
            initial_pairs += min(count, count_b[val])
            
        max_gain = 0
        for val, count in count_a.items():
            if count > count_b[val]:
                max_gain = 1
                break
                
        min_loss = 1
        for val in b:
            if count_b[val] > count_a[val]:
                min_loss = 0
                break
                
        print(initial_pairs + max_gain - min_loss)

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:频率统计、贪心

  • 时间复杂度: ,其中 N 是数组长度,U_A 是数组 A 中不同元素的数量。主要开销来自于遍历数组和哈希表。由于 ,总时间复杂度为

  • 空间复杂度: ,用于存储两个哈希表,其中 分别是 AB 中不同元素的数量。最坏情况下为