题目链接
题目描述
给定两个长度为 N
的整型数组 A
和 B
。如果 A[i] == B[j]
,则 (i, j)
可构成一个配对。
一个“最佳配对集合”是指一系列配对,其中 A
和 B
的每个索引最多只出现一次。
你的任务是修改 B
中的一个元素,使得“最佳配对集合”的规模(即配对数量)最大。输出这个最大数量。
解题思路
这个问题的核心在于分析单次修改操作对总配对数的影响。我们可以通过频率统计和贪心分析,在接近线性的时间内找到最优解。
1. 计算初始配对数
首先,如果不做任何修改,能形成的最大配对数是多少?
对于任意一个数值 v
,它在 A
中出现了 countA[v]
次,在 B
中出现了 countB[v]
次。由于每个索引只能用一次,数值 v
最多能贡献 min(countA[v], countB[v])
个配对。
因此,初始的最大配对数 M
等于所有数值 v
的 min(countA[v], countB[v])
之和。这可以通过哈希表在 时间内计算出来。
2. 分析修改操作的收益
我们必须修改 B
中的一个元素,比如将 B[j]
的值从 v_old
改为 v_new
。这一操作对总配对数的影响可以分解为“损失”和“增益”两部分。
-
损失 (Loss):
B
中v_old
的数量减一。这是否会减少总配对数?- 如果原本
B
中v_old
的数量就超过了A
中的数量(即countB[v_old] > countA[v_old]
),说明B
中有“多余的”v_old
,减少一个并不会影响min(countA, countB)
的值。没有损失。 - 如果原本
B
中v_old
的数量不多于A
中的数量(即countB[v_old] <= countA[v_old]
),那么减少一个就会使min
的值减一。损失1个配对。 - 我们可以定义一个
loss
变量,当countA[v_old] >= countB[v_old]
时为1,否则为0。
- 如果原本
-
增益 (Gain):
B
中v_new
的数量加一。这是否会增加总配对数?- 如果原本
A
中v_new
的数量多于B
中的数量(即countA[v_new] > countB[v_new]
),说明A
中有“待匹配的”v_new
,此时B
中增加一个正好可以与之配对。增加1个配对。 - 如果原本
A
中v_new
的数量不多于B
中的数量(即countA[v_new] <= countB[v_new]
),B
再增加一个v_new
也找不到A
中新的元素来匹配。没有增益。 - 我们可以定义一个
gain
变量,当countA[v_new] > countB[v_new]
时为1,否则为0。
- 如果原本
3. 最优策略
修改后的总配对数 M' = M - loss + gain
。为了最大化 M'
,我们需要最大化 gain - loss
的值。
gain
和 loss
的取值都是 0 或 1。gain - loss
的可能取值为 1, 0, -1
。我们可以分别寻找最优的 gain
和最优的 loss
:
-
最大化增益 (
max_gain
):我们希望gain
为 1。这需要我们选择的新值v_new
满足countA[v_new] > countB[v_new]
。我们只需检查是否存在任何一个这样的v_new
即可。如果存在,max_gain = 1
;否则max_gain = 0
。 -
最小化损失 (
min_loss
):我们希望loss
为 0。这需要我们选择被修改的B[j]
(其值为v_old
)满足countB[v_old] > countA[v_old]
。我们只需检查B
数组中是否存在任何一个这样的元素即可。如果存在,min_loss = 0
;否则,无论修改哪个元素都会造成损失,min_loss = 1
。
最终,能达到的最大配对数就是 M + max_gain - min_loss
。
算法步骤总结
- 使用哈希表统计
A
和B
中每个数字的出现次数,得到countA
和countB
。 - 计算初始配对数
M = sum(min(countA[v], countB[v]))
。 - 初始化
max_gain = 0
。遍历countA
,如果发现任何v
使得countA[v] > countB.get(v, 0)
,则将max_gain
设为 1 并终止遍历。 - 初始化
min_loss = 1
。遍历数组B
,如果发现任何元素b
使得countB[b] > countA.get(b, 0)
,则将min_loss
设为 0 并终止遍历。 - 最终答案为
M + max_gain - min_loss
。
代码
#include <iostream>
#include <vector>
#include <unordered_map>
#include <algorithm>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<int> a(n), b(n);
unordered_map<int, int> count_a, count_b;
for (int i = 0; i < n; ++i) {
cin >> a[i];
count_a[a[i]]++;
}
for (int i = 0; i < n; ++i) {
cin >> b[i];
count_b[b[i]]++;
}
long long initial_pairs = 0;
for (auto const& [val, count] : count_a) {
if (count_b.count(val)) {
initial_pairs += min(count, count_b[val]);
}
}
int max_gain = 0;
for (auto const& [val, count] : count_a) {
int b_count = count_b.count(val) ? count_b[val] : 0;
if (count > b_count) {
max_gain = 1;
break;
}
}
int min_loss = 1;
for (int val : b) {
int a_count = count_a.count(val) ? count_a[val] : 0;
if (count_b[val] > a_count) {
min_loss = 0;
break;
}
}
cout << initial_pairs + max_gain - min_loss << endl;
return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
int[] b = new int[n];
Map<Integer, Integer> countA = new HashMap<>();
Map<Integer, Integer> countB = new HashMap<>();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
countA.put(a[i], countA.getOrDefault(a[i], 0) + 1);
}
for (int i = 0; i < n; i++) {
b[i] = sc.nextInt();
countB.put(b[i], countB.getOrDefault(b[i], 0) + 1);
}
long initialPairs = 0;
for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
int val = entry.getKey();
int count = entry.getValue();
initialPairs += Math.min(count, countB.getOrDefault(val, 0));
}
int maxGain = 0;
for (Map.Entry<Integer, Integer> entry : countA.entrySet()) {
int val = entry.getKey();
int count = entry.getValue();
if (count > countB.getOrDefault(val, 0)) {
maxGain = 1;
break;
}
}
int minLoss = 1;
for (int val : b) {
if (countB.get(val) > countA.getOrDefault(val, 0)) {
minLoss = 0;
break;
}
}
System.out.println(initialPairs + maxGain - minLoss);
}
}
import sys
from collections import Counter
def solve():
try:
n_str = sys.stdin.readline()
if not n_str:
return
n = int(n_str)
a = list(map(int, sys.stdin.readline().split()))
b = list(map(int, sys.stdin.readline().split()))
count_a = Counter(a)
count_b = Counter(b)
initial_pairs = 0
for val, count in count_a.items():
initial_pairs += min(count, count_b[val])
max_gain = 0
for val, count in count_a.items():
if count > count_b[val]:
max_gain = 1
break
min_loss = 1
for val in b:
if count_b[val] > count_a[val]:
min_loss = 0
break
print(initial_pairs + max_gain - min_loss)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:频率统计、贪心
-
时间复杂度:
,其中
N
是数组长度,U_A
是数组A
中不同元素的数量。主要开销来自于遍历数组和哈希表。由于,总时间复杂度为
。
-
空间复杂度:
,用于存储两个哈希表,其中
和
分别是
A
和B
中不同元素的数量。最坏情况下为。