题目链接
题目描述
n 个点连成一条链,编号从 1 到 n。第 i 条边连接点 i 和 i+1。
每个点 i 有一个点权 a[i]。
第 i 条边的权重 w[i] 定义为:满足 x <= i 且 y > i,并且 a[x] != a[y] 的点对 (x, y) 的数量。
给定所有点的点权,请求出所有 n-1 条边的权重。
解题思路
这是一个计数问题。对于每一条边 i,我们需要计算它左右两侧点权不同的点对数量。
1. 核心思想:正难则反
直接计算“不同”的点对数量比较复杂。我们可以换一个思路:
w[i] = (总点对数) - (点权相同的点对数)
- 总点对数:对于第
i条边,其左侧有i个点(从1到i),右侧有n-i个点(从i+1到n)。所以,跨越这条边的总点对数量是i * (n-i)。 - 点权相同的点对数:我们需要计算满足
x <= i,y > i, 且a[x] == a[y]的点对数量。
2. 高效计算相同点对: 的暴力解法
暴力计算每一条边的相同点对数,总时间复杂度会达到 。
对于第
i 条边,我们可以:
- 建立一个哈希表
left_counts统计a[0...i]中每个点权的频率。 - 建立另一个哈希表
right_counts统计a[i+1...n-1]中每个点权的频率。 - 遍历
left_counts,对于每个点权v,相同点对数 += left_counts[v] * right_counts[v]。
对每条边都重复这个过程,总复杂度就是 ,在
N=100000 时会超时。
3. 算法优化:寻找递推关系
我们可以发现,当分割点从 i 移动到 i+1 时,只有点 a[i+1] 从右侧集合移动到了左侧集合。这意味着 SamePairs(i+1) 和 SamePairs(i) 之间存在密切的联系。我们可以利用这个关系进行快速计算,而无需每次都重新统计。
设 sp(i) 为以第 i 条边为分割的相同点对数。
sp(i) = Σ (count_left(v, i) * count_right(v, i)) (对所有点权 v 求和)
当我们从边 i-1 移动到边 i 时,点 a[i] 从右侧移动到左侧。设 v_new = a[i]。
对于 v_new,它在左侧的数量增加了1,在右侧的数量减少了1。
对于其他 v != v_new,其在左右两侧的数量不变。
经过推导,我们可以得到 sp(i) 和 sp(i-1) 的递推关系:
sp(i) = sp(i-1) + count_right(v_new, i-1) - count_left(v_new, i-1) - 1
其中 count_right(v_new, i-1) 是 v_new 在 {a[i], ..., a[n-1]} 中的数量,count_left(v_new, i-1) 是 v_new 在 {a[0], ..., a[i-1]} 中的数量。
利用 count_right + count_left = total_count,公式可以进一步化为:
sp(i) = sp(i-1) + total_count(v_new) - 2 * count_left(v_new, i-1) - 1
这个递推关系允许我们仅用上一步的结果和一些计数值,在 (哈希表) 或
(
map) 的时间内计算出当前步的结果。
算法步骤总结 (优化后)
- 用哈希表
total_counts统计数组a中所有数字的频率。 - 初始化一个空的哈希表
left_counts,用于记录当前分割点左侧的数字频率。 - 初始化
total_same_pairs = 0。 - 循环
i从0到n-2(对应第0到n-2条边):v_new = a[i]。count_in_prev_left = left_counts.get(v_new, 0)。- 使用递推公式更新
total_same_pairs:- 如果是第一次(
i=0),total_same_pairs = 1 * (total_counts[v_new] - 1)。 - 否则,
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1。
- 如果是第一次(
- 更新
left_counts:left_counts[v_new]++。 - 计算权重
weights[i]:left_size = i + 1。right_size = n - left_size。weights[i] = left_size * right_size - total_same_pairs。
- 输出
weights数组。
代码
#include <iostream>
#include <vector>
#include <map>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<int> a(n);
map<int, int> total_counts;
for (int i = 0; i < n; ++i) {
cin >> a[i];
total_counts[a[i]]++;
}
map<int, int> left_counts;
vector<long long> weights(n - 1);
long long total_same_pairs = 0;
for (int i = 0; i < n - 1; ++i) {
int v_new = a[i];
long long count_in_prev_left = 0;
if (left_counts.count(v_new)) {
count_in_prev_left = left_counts[v_new];
}
// sp(i) = sp(i-1) + total_counts(v) - 2*left_counts(v, i-1) - 1
// Note: this recurrence applies for i > 0. For i = 0, we compute directly.
// For simplicity, we can think of sp(-1) = 0, left_counts(-1) is empty.
// The change is from point a[i] moving from right to left.
// It forms 'count_in_prev_left' new same-pairs with left side, and loses 'count_in_right' same-pairs with right side.
// A bit tricky. Let's use the algebraic update.
if (i > 0) {
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1;
} else { // i == 0
total_same_pairs = 1LL * (total_counts[v_new] - 1);
}
left_counts[v_new]++;
long long left_size = i + 1;
long long right_size = n - left_size;
weights[i] = left_size * right_size - total_same_pairs;
}
for (int i = 0; i < n - 1; ++i) {
cout << weights[i] << (i == n - 2 ? "" : " ");
}
cout << endl;
return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;
public class Main {
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
int n = sc.nextInt();
int[] a = new int[n];
Map<Integer, Integer> totalCounts = new HashMap<>();
for (int i = 0; i < n; i++) {
a[i] = sc.nextInt();
totalCounts.put(a[i], totalCounts.getOrDefault(a[i], 0) + 1);
}
Map<Integer, Integer> leftCounts = new HashMap<>();
long[] weights = new long[n - 1];
long totalSamePairs = 0;
for (int i = 0; i < n - 1; i++) {
int vNew = a[i];
long countInPrevLeft = leftCounts.getOrDefault(vNew, 0);
if (i > 0) {
totalSamePairs = totalSamePairs + totalCounts.get(vNew) - 2 * countInPrevLeft - 1;
} else { // i == 0
totalSamePairs = (long)totalCounts.get(vNew) - 1;
}
leftCounts.put(vNew, leftCounts.getOrDefault(vNew, 0) + 1);
long leftSize = i + 1;
long rightSize = n - leftSize;
weights[i] = leftSize * rightSize - totalSamePairs;
}
for (int i = 0; i < n - 1; i++) {
System.out.print(weights[i] + (i == n - 2 ? "" : " "));
}
System.out.println();
}
}
import sys
from collections import Counter
def solve():
try:
n_str = sys.stdin.readline()
if not n_str:
return
n = int(n_str)
a = list(map(int, sys.stdin.readline().split()))
total_counts = Counter(a)
left_counts = Counter()
weights = []
total_same_pairs = 0
for i in range(n - 1):
v_new = a[i]
count_in_prev_left = left_counts[v_new]
if i > 0:
total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1
else: # i == 0
total_same_pairs = total_counts[v_new] - 1
left_counts[v_new] += 1
left_size = i + 1
right_size = n - left_size
weight = left_size * right_size - total_same_pairs
weights.append(weight)
print(*weights)
except (IOError, ValueError):
return
solve()
算法及复杂度
-
算法:计数、哈希表、动态规划(递推)
-
时间复杂度:
或
。预处理统计总频率需要
(使用
map) 或(使用
unordered_map)。之后,计算n-1个边权,每一步的递推和更新哈希表需要或
的时间。所以总时间复杂度由哈希表的实现决定。
-
空间复杂度:
,其中
U是整个数组中不同点权的数量。用于存储频率哈希表。

京公网安备 11010502036488号