题目链接
题目描述
n
个点连成一条链,编号从 1
到 n
。第 i
条边连接点 i
和 i+1
。
每个点 i
有一个点权 a[i]
。
第 i
条边的权重 w[i]
定义为:满足 x <= i
且 y > i
,并且 a[x] != a[y]
的点对 (x, y)
的数量。
给定所有点的点权,请求出所有 n-1
条边的权重。
解题思路
这是一个计数问题。对于每一条边 i
,我们需要计算它左右两侧点权不同的点对数量。
1. 核心思想:正难则反
直接计算“不同”的点对数量比较复杂。我们可以换一个思路:
w[i] = (总点对数) - (点权相同的点对数)
- 总点对数:对于第
i
条边,其左侧有i
个点(从1到i),右侧有n-i
个点(从i+1到n)。所以,跨越这条边的总点对数量是i * (n-i)
。 - 点权相同的点对数:我们需要计算满足
x <= i
,y > i
, 且a[x] == a[y]
的点对数量。
2. 高效计算相同点对: 的暴力解法
暴力计算每一条边的相同点对数,总时间复杂度会达到 。
对于第
i
条边,我们可以:
- 建立一个哈希表
left_counts
统计a[0...i]
中每个点权的频率。 - 建立另一个哈希表
right_counts
统计a[i+1...n-1]
中每个点权的频率。 - 遍历
left_counts
,对于每个点权v
,相同点对数 += left_counts[v] * right_counts[v]
。
对每条边都重复这个过程,总复杂度就是 ,在
N=100000
时会超时。
3. 算法优化:寻找递推关系
我们可以发现,当分割点从 i
移动到 i+1
时,只有点 a[i+1]
从右侧集合移动到了左侧集合。这意味着 SamePairs(i+1)
和 SamePairs(i)
之间存在密切的联系。我们可以利用这个关系进行快速计算,而无需每次都重新统计。
设 sp(i)
为以第 i
条边为分割的相同点对数。
sp(i) = Σ (count_left(v, i) * count_right(v, i))
(对所有点权 v
求和)
当我们从边 i-1
移动到边 i
时,点 a[i]
从右侧移动到左侧。设 v_new = a[i]
。
对于 v_new
,它在左侧的数量增加了1,在右侧的数量减少了1。
对于其他 v != v_new
,其在左右两侧的数量不变。
经过推导,我们可以得到 sp(i)
和 sp(i-1)
的递推关系:
sp(i) = sp(i-1) + count_right(v_new, i-1) - count_left(v_new, i-1) - 1
其中 count_right(v_new, i-1)
是 v_new
在 {a[i], ..., a[n-1]}
中的数量,count_left(v_new, i-1)
是 v_new
在 {a[0], ..., a[i-1]}
中的数量。
利用 count_right + count_left = total_count
,公式可以进一步化为:
sp(i) = sp(i-1) + total_count(v_new) - 2 * count_left(v_new, i-1) - 1
这个递推关系允许我们仅用上一步的结果和一些计数值,在 (哈希表) 或
(
map
) 的时间内计算出当前步的结果。
算法步骤总结 (优化后)
- 用哈希表
total_counts
统计数组a
中所有数字的频率。 - 初始化一个空的哈希表
left_counts
,用于记录当前分割点左侧的数字频率。 - 初始化
total_same_pairs = 0
。 - 循环
i
从0
到n-2
(对应第0
到n-2
条边):v_new = a[i]
。count_in_prev_left = left_counts.get(v_new, 0)
。- 使用递推公式更新
total_same_pairs
:- 如果是第一次(
i=0
),total_same_pairs = 1 * (total_counts[v_new] - 1)
。 - 否则,
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1
。
- 如果是第一次(
- 更新
left_counts
:left_counts[v_new]++
。 - 计算权重
weights[i]
:left_size = i + 1
。right_size = n - left_size
。weights[i] = left_size * right_size - total_same_pairs
。
- 输出
weights
数组。
代码
#include <iostream>
#include <vector>
#include <map>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<int> a(n);
map<int, int> total_counts;
for (int i = 0; i < n; ++i) {
cin >> a[i];
total_counts[a[i]]++;
}
map<int, int> left_counts;
vector<long long> weights(n - 1);
long long total_same_pairs = 0;
for (int i = 0; i < n - 1; ++i) {
int v_new = a[i];
long long count_in_prev_left = 0;
if (left_counts.count(v_new)) {
count_in_prev_left = left_counts[v_new];
}
// sp(i) = sp(i-1) + total_counts(v) - 2*left_counts(v, i-1) - 1
// Note: this recurrence applies for i > 0. For i = 0, we compute directly.
// For simplicity, we can think of sp(-1) = 0, left_counts(-1) is empty.
// The change is from point a[i] moving from right to left.
// It forms 'count_in_prev_left' new same-pairs with left side, and loses 'count_in_right' same-pairs with right side.
// A bit tricky. Let's use the algebraic update.
if (i > 0) {
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1;
} else { // i == 0
total_same_pairs = 1LL * (total_counts[v_new] - 1);
}
left_counts[v_new]++;
long long left_size = i + 1;
long long right_size = n - left_size;
weights[i] = left_size * right_size - total_same_pairs;
}
for (int i = 0; i < n - 1; ++i) {
cout << weights[i] << (i == n - 2 ? "" : " ");
}
cout << endl;
return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
Map<Integer, Integer> totalCounts = new HashMap<>();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
totalCounts.put(a[i], totalCounts.getOrDefault(a[i], 0) + 1);
}
Map<Integer, Integer> leftCounts = new HashMap<>();
long[] weights = new long[n - 1];
long totalSamePairs = 0;
for (int i = 0; i < n - 1; i++) {
int vNew = a[i];
long countInPrevLeft = leftCounts.getOrDefault(vNew, 0);
if (i > 0) {
totalSamePairs = totalSamePairs + totalCounts.get(vNew) - 2 * countInPrevLeft - 1;
} else { // i == 0
totalSamePairs = (long)totalCounts.get(vNew) - 1;
}
leftCounts.put(vNew, leftCounts.getOrDefault(vNew, 0) + 1);
long leftSize = i + 1;
long rightSize = n - leftSize;
weights[i] = leftSize * rightSize - totalSamePairs;
}
for (int i = 0; i < n - 1; i++) {
System.out.print(weights[i] + (i == n - 2 ? "" : " "));
}
System.out.println();
}
}
import sys
from collections import Counter
def solve():
try:
n_str = sys.stdin.readline()
if not n_str:
return
n = int(n_str)
a = list(map(int, sys.stdin.readline().split()))
total_counts = Counter(a)
left_counts = Counter()
weights = []
total_same_pairs = 0
for i in range(n - 1):
v_new = a[i]
count_in_prev_left = left_counts[v_new]
if i > 0:
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1
else: # i == 0
total_same_pairs = total_counts[v_new] - 1
left_counts[v_new] += 1
left_size = i + 1
right_size = n - left_size
weight = left_size * right_size - total_same_pairs
weights.append(weight)
print(*weights)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:计数、哈希表、动态规划(递推)
-
时间复杂度:
或
。预处理统计总频率需要
(使用
map
) 或(使用
unordered_map
)。之后,计算n-1
个边权,每一步的递推和更新哈希表需要或
的时间。所以总时间复杂度由哈希表的实现决定。
-
空间复杂度:
,其中
U
是整个数组中不同点权的数量。用于存储频率哈希表。