题目链接
题目描述
给定两个长度为 N 的整型数组 A 和 B。如果 A[i] == B[j],则 (i, j) 可构成一个配对。
一个“最佳配对集合”是指一系列配对,其中 A 和 B 的每个索引最多只出现一次。
你的任务是修改 B 中的一个元素,使得“最佳配对集合”的规模(即配对数量)最大。输出这个最大数量。
解题思路
这个问题的核心在于分析单次修改操作对总配对数的影响。我们可以通过频率统计和贪心分析,在接近线性的时间内找到最优解。
1. 计算初始配对数
首先,如果不做任何修改,能形成的最大配对数是多少?
对于任意一个数值 v,它在 A 中出现了 countA[v] 次,在 B 中出现了 countB[v] 次。由于每个索引只能用一次,数值 v 最多能贡献 min(countA[v], countB[v]) 个配对。
因此,初始的最大配对数 M 等于所有数值 v 的 min(countA[v], countB[v]) 之和。这可以通过哈希表在 时间内计算出来。
2. 分析修改操作的收益
我们必须修改 B 中的一个元素,比如将 B[j] 的值从 v_old 改为 v_new。这一操作对总配对数的影响可以分解为“损失”和“增益”两部分。
-
损失 (Loss):
B中v_old的数量减一。这是否会减少总配对数?- 如果原本
B中v_old的数量就超过了A中的数量(即countB[v_old] > countA[v_old]),说明B中有“多余的”v_old,减少一个并不会影响min(countA, countB)的值。没有损失。 - 如果原本
B中v_old的数量不多于A中的数量(即countB[v_old] <= countA[v_old]),那么减少一个就会使min的值减一。损失1个配对。 - 我们可以定义一个
loss变量,当countA[v_old] >= countB[v_old]时为1,否则为0。
- 如果原本
-
增益 (Gain):
B中v_new的数量加一。这是否会增加总配对数?- 如果原本
A中v_new的数量多于B中的数量(即countA[v_new] > countB[v_new]),说明A中有“待匹配的”v_new,此时B中增加一个正好可以与之配对。增加1个配对。 - 如果原本
A中v_new的数量不多于B中的数量(即countA[v_new] <= countB[v_new]),B再增加一个v_new也找不到A中新的元素来匹配。没有增益。 - 我们可以定义一个
gain变量,当countA[v_new] > countB[v_new]时为1,否则为0。
- 如果原本
3. 最优策略
修改后的总配对数 M' = M - loss + gain。为了最大化 M',我们需要最大化 gain - loss 的值。
gain 和 loss 的取值都是 0 或 1。gain - loss 的可能取值为 1, 0, -1。我们可以分别寻找最优的 gain 和最优的 loss:
-
最大化增益 (
max_gain):我们希望gain为 1。这需要我们选择的新值v_new满足countA[v_new] > countB[v_new]。我们只需检查是否存在任何一个这样的v_new即可。如果存在,max_gain = 1;否则max_gain = 0。 -
最小化损失 (
min_loss):我们希望loss为 0。这需要我们选择被修改的B[j](其值为v_old)满足countB[v_old] > countA[v_old]。我们只需检查B数组中是否存在任何一个这样的元素即可。如果存在,min_loss = 0;否则,无论修改哪个元素都会造成损失,min_loss = 1。
最终,能达到的最大配对数就是 M + max_gain - min_loss。
算法步骤总结
- 使用哈希表统计
A和B中每个数字的出现次数,得到countA和countB。 - 计算初始配对数
M = sum(min(countA[v], countB[v]))。 - 初始化
max_gain = 0。遍历countA,如果发现任何v使得countA[v] > countB.get(v, 0),则将max_gain设为 1 并终止遍历。 - 初始化
min_loss = 1。遍历数组B,如果发现任何元素b使得countB[b] > countA.get(b, 0),则将min_loss设为 0 并终止遍历。 - 最终答案为
M + max_gain - min_loss。
代码
#include <iostream>
#include <vector>
#include <unordered_map>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<int> a(n), b(n);
unordered_map<int, int> count_a, count_b;
for (int i = 0; i < n; ++i) {
cin >> a[i];
count_a[a[i]]++;
}
for (int i = 0; i < n; ++i) {
cin >> b[i];
count_b[b[i]]++;
}
long long initial_pairs = 0;
for (auto const& [val, count] : count_a) {
if (count_b.count(val)) {
initial_pairs += min(count, count_b[val]);
}
}
int max_gain = 0;
for (auto const& [val, count] : count_a) {
int b_count = count_b.count(val) ? count_b[val] : 0;
if (count > b_count) {
max_gain = 1;
break;
}
}
int min_loss = 1;
for (int val : b) {
int a_count = count_a.count(val) ? count_a[val] : 0;
if (count_b[val] > a_count) {
min_loss = 0;
break;
}
}
cout << initial_pairs + max_gain - min_loss << endl;
return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
int[] b = new int[n];
Map<Integer, Integer> countA = new HashMap<>();
Map<Integer, Integer> countB = new HashMap<>();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
countA.put(a[i], countA.getOrDefault(a[i], 0) + 1);
}
for (int i = 0; i < n; i++) {
b[i] = sc.nextInt();
countB.put(b[i], countB.getOrDefault(b[i], 0) + 1);
}
long initialPairs = 0;
for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
int val = entry.getKey();
int count = entry.getValue();
initialPairs += Math.min(count, countB.getOrDefault(val, 0));
}
int maxGain = 0;
for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
int val = entry.getKey();
int count = entry.getValue();
if (count > countB.getOrDefault(val, 0)) {
maxGain = 1;
break;
}
}
int minLoss = 1;
for (int val : b) {
if (countB.get(val) > countA.getOrDefault(val, 0)) {
minLoss = 0;
break;
}
}
System.out.println(initialPairs + maxGain - minLoss);
}
}
import sys
from collections import Counter
def solve():
try:
n_str = sys.stdin.readline()
if not n_str:
return
n = int(n_str)
a = list(map(int, sys.stdin.readline().split()))
b = list(map(int, sys.stdin.readline().split()))
count_a = Counter(a)
count_b = Counter(b)
initial_pairs = 0
for val, count in count_a.items():
initial_pairs += min(count, count_b[val])
max_gain = 0
for val, count in count_a.items():
if count > count_b[val]:
max_gain = 1
break
min_loss = 1
for val in b:
if count_b[val] > count_a[val]:
min_loss = 0
break
print(initial_pairs + max_gain - min_loss)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:频率统计、贪心
-
时间复杂度:
,其中
N是数组长度,U_A是数组A中不同元素的数量。主要开销来自于遍历数组和哈希表。由于,总时间复杂度为
。
-
空间复杂度:
,用于存储两个哈希表,其中
和
分别是
A和B中不同元素的数量。最坏情况下为。

京公网安备 11010502036488号