题目链接

链式边权

题目描述

n 个点连成一条链,编号从 1n。第 i 条边连接点 ii+1。 每个点 i 有一个点权 a[i]

i 条边的权重 w[i] 定义为:满足 x <= iy > i,并且 a[x] != a[y] 的点对 (x, y) 的数量。

给定所有点的点权,请求出所有 n-1 条边的权重。

解题思路

这是一个计数问题。对于每一条边 i,我们需要计算它左右两侧点权不同的点对数量。

1. 核心思想:正难则反

直接计算“不同”的点对数量比较复杂。我们可以换一个思路: w[i] = (总点对数) - (点权相同的点对数)

  • 总点对数:对于第 i 条边,其左侧有 i 个点(从1到i),右侧有 n-i 个点(从i+1到n)。所以,跨越这条边的总点对数量是 i * (n-i)
  • 点权相同的点对数:我们需要计算满足 x <= i, y > i, 且 a[x] == a[y] 的点对数量。

2. 高效计算相同点对: 的暴力解法

暴力计算每一条边的相同点对数,总时间复杂度会达到 。 对于第 i 条边,我们可以:

  1. 建立一个哈希表 left_counts 统计 a[0...i] 中每个点权的频率。
  2. 建立另一个哈希表 right_counts 统计 a[i+1...n-1] 中每个点权的频率。
  3. 遍历 left_counts,对于每个点权 v相同点对数 += left_counts[v] * right_counts[v]

对每条边都重复这个过程,总复杂度就是 ,在 N=100000 时会超时。

3. 算法优化:寻找递推关系

我们可以发现,当分割点从 i 移动到 i+1 时,只有点 a[i+1] 从右侧集合移动到了左侧集合。这意味着 SamePairs(i+1)SamePairs(i) 之间存在密切的联系。我们可以利用这个关系进行快速计算,而无需每次都重新统计。

sp(i) 为以第 i 条边为分割的相同点对数。 sp(i) = Σ (count_left(v, i) * count_right(v, i)) (对所有点权 v 求和)

当我们从边 i-1 移动到边 i 时,点 a[i] 从右侧移动到左侧。设 v_new = a[i]。 对于 v_new,它在左侧的数量增加了1,在右侧的数量减少了1。 对于其他 v != v_new,其在左右两侧的数量不变。

经过推导,我们可以得到 sp(i)sp(i-1) 的递推关系: sp(i) = sp(i-1) + count_right(v_new, i-1) - count_left(v_new, i-1) - 1

其中 count_right(v_new, i-1)v_new{a[i], ..., a[n-1]} 中的数量,count_left(v_new, i-1)v_new{a[0], ..., a[i-1]} 中的数量。

利用 count_right + count_left = total_count,公式可以进一步化为: sp(i) = sp(i-1) + total_count(v_new) - 2 * count_left(v_new, i-1) - 1

这个递推关系允许我们仅用上一步的结果和一些计数值,在 (哈希表) 或 (map) 的时间内计算出当前步的结果。

算法步骤总结 (优化后)

  1. 用哈希表 total_counts 统计数组 a 中所有数字的频率。
  2. 初始化一个空的哈希表 left_counts,用于记录当前分割点左侧的数字频率。
  3. 初始化 total_same_pairs = 0
  4. 循环 i0n-2 (对应第 0n-2 条边):
    • v_new = a[i]
    • count_in_prev_left = left_counts.get(v_new, 0)
    • 使用递推公式更新 total_same_pairs
      • 如果是第一次(i=0),total_same_pairs = 1 * (total_counts[v_new] - 1)
      • 否则,total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1
    • 更新 left_countsleft_counts[v_new]++
    • 计算权重 weights[i]
      • left_size = i + 1
      • right_size = n - left_size
      • weights[i] = left_size * right_size - total_same_pairs
  5. 输出 weights 数组。

代码

#include <iostream>
#include <vector>
#include <map>

using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<int> a(n);
    map<int, int> total_counts;
    for (int i = 0; i < n; ++i) {
        cin >> a[i];
        total_counts[a[i]]++;
    }

    map<int, int> left_counts;
    vector<long long> weights(n - 1);
    long long total_same_pairs = 0;

    for (int i = 0; i < n - 1; ++i) {
        int v_new = a[i];
        
        long long count_in_prev_left = 0;
        if (left_counts.count(v_new)) {
            count_in_prev_left = left_counts[v_new];
        }

        // sp(i) = sp(i-1) + total_counts(v) - 2*left_counts(v, i-1) - 1
        // Note: this recurrence applies for i > 0. For i = 0, we compute directly.
        // For simplicity, we can think of sp(-1) = 0, left_counts(-1) is empty.
        // The change is from point a[i] moving from right to left.
        // It forms 'count_in_prev_left' new same-pairs with left side, and loses 'count_in_right' same-pairs with right side.
        // A bit tricky. Let's use the algebraic update.
        if (i > 0) {
             total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1;
        } else { // i == 0
             total_same_pairs = 1LL * (total_counts[v_new] - 1);
        }
        
        left_counts[v_new]++;

        long long left_size = i + 1;
        long long right_size = n - left_size;
        weights[i] = left_size * right_size - total_same_pairs;
    }

    for (int i = 0; i < n - 1; ++i) {
        cout << weights[i] << (i == n - 2 ? "" : " ");
    }
    cout << endl;

    return 0;
}
import java.util.HashMap;
import java.util.Map;
import java.util.Scanner;

public class Main {
    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        int n = sc.nextInt();
        int[] a = new int[n];
        Map<Integer, Integer> totalCounts = new HashMap<>();

        for (int i = 0; i < n; i++) {
            a[i] = sc.nextInt();
            totalCounts.put(a[i], totalCounts.getOrDefault(a[i], 0) + 1);
        }

        Map<Integer, Integer> leftCounts = new HashMap<>();
        long[] weights = new long[n - 1];
        long totalSamePairs = 0;

        for (int i = 0; i < n - 1; i++) {
            int vNew = a[i];
            
            long countInPrevLeft = leftCounts.getOrDefault(vNew, 0);
            
            if (i > 0) {
                totalSamePairs = totalSamePairs + totalCounts.get(vNew) - 2 * countInPrevLeft - 1;
            } else { // i == 0
                totalSamePairs = (long)totalCounts.get(vNew) - 1;
            }

            leftCounts.put(vNew, leftCounts.getOrDefault(vNew, 0) + 1);

            long leftSize = i + 1;
            long rightSize = n - leftSize;
            weights[i] = leftSize * rightSize - totalSamePairs;
        }

        for (int i = 0; i < n - 1; i++) {
            System.out.print(weights[i] + (i == n - 2 ? "" : " "));
        }
        System.out.println();
    }
}
import sys
from collections import Counter

def solve():
    try:
        n_str = sys.stdin.readline()
        if not n_str:
            return
        n = int(n_str)
        a = list(map(int, sys.stdin.readline().split()))
        
        total_counts = Counter(a)
        left_counts = Counter()
        weights = []
        total_same_pairs = 0
        
        for i in range(n - 1):
            v_new = a[i]
            
            count_in_prev_left = left_counts[v_new]
            
            if i > 0:
                total_same_pairs = total_same_pairs + total_counts[v_new] - 2 * count_in_prev_left - 1
            else: # i == 0
                total_same_pairs = total_counts[v_new] - 1

            left_counts[v_new] += 1
            
            left_size = i + 1
            right_size = n - left_size
            
            weight = left_size * right_size - total_same_pairs
            weights.append(weight)
            
        print(*weights)

    except (IOError, ValueError):
        return

solve()

算法及复杂度

  • 算法:计数、哈希表、动态规划(递推)

  • 时间复杂度: 。预处理统计总频率需要 (使用 map) 或 (使用 unordered_map)。之后,计算 n-1 个边权,每一步的递推和更新哈希表需要 的时间。所以总时间复杂度由哈希表的实现决定。

  • 空间复杂度: ,其中 U 是整个数组中不同点权的数量。用于存储频率哈希表。