小红的加边距离

[题目链接](https://www.nowcoder.com/practice/b53afc9bbc214e4e8a465d85ed2563d9)

思路

关键观察:加边后的距离公式

题目允许添加边 ,条件是原树中存在节点 使得 都是原树的边(即 在原树中有公共邻居)。

加边策略:对于原树中任意两节点 ,若它们在原树中的距离 且存在公共邻居,则可以直接连边。而对于树中距离为 2 的两个节点,它们一定共享同一个中间节点作为公共邻居。因此,将所有可能的边都加完后,两点之间的新距离为:

$$

直觉上,新图中每一步可以"跳过"一个中间节点,因此相当于原距离除以 2 再向上取整。

样例验证:以示例树(5个点,边 1-2, 1-3, 2-4, 2-5)为例,节点 3 到节点 4 的原距离为 3,加边后距离为 ,与样例说明一致。

化简求和公式

设所有有序点对 )的答案为:

$$

利用恒等式 ,得:

$$

即:

$$

其中 是原树所有有序点对距离之和, 是距离为奇数的有序点对数量。

两个子问题的高效计算

子问题 1:树上所有点对距离之和

经典树形DP(贡献法):对树以节点 1 为根,每条边 将树分为两棵子树,大小分别为 。该边对所有点对距离之和的贡献(有序对)为:

$$

对所有非根节点 求和即得 ,时间复杂度

子问题 2:奇数距离点对数

在树中,。由于 是偶数,所以:

$$

因此, 的距离为奇数,当且仅当它们的深度奇偶性不同(一个在偶数层,一个在奇数层)。

设深度为偶数的节点有 个,深度为奇数的节点有 个(),则:

$$

(有序对,乘以 2)

样例演示

树:1-2, 1-3, 2-4, 2-5,以节点 1 为根:

  • 深度:(偶),(奇),(奇),(偶),(偶)

各边贡献:

  • 边 (1,2):,贡献
  • 边 (1,3):,贡献
  • 边 (2,4):,贡献
  • 边 (2,5):,贡献

复杂度

  • 时间:(一次 BFS 计算深度和子树大小)
  • 空间:

代码

C++

#include <iostream>
#include <vector>
#include <queue>
using namespace std;

int main() {
    ios_base::sync_with_stdio(false);
    cin.tie(NULL);

    int n;
    cin >> n;

    vector<vector<int>> adj(n + 1);
    for (int i = 0; i < n - 1; i++) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    vector<long long> subtree_size(n + 1, 0);
    vector<long long> depth(n + 1, -1);
    vector<int> order;
    vector<int> parent(n + 1, 0);

    depth[1] = 0;
    queue<int> q;
    q.push(1);
    while (!q.empty()) {
        int u = q.front(); q.pop();
        order.push_back(u);
        subtree_size[u] = 1;
        for (int v : adj[u]) {
            if (depth[v] == -1) {
                depth[v] = depth[u] + 1;
                parent[v] = u;
                q.push(v);
            }
        }
    }

    for (int i = order.size() - 1; i >= 0; i--) {
        int u = order[i];
        if (parent[u] != 0) {
            subtree_size[parent[u]] += subtree_size[u];
        }
    }

    long long sum_dist = 0;
    for (int v = 2; v <= n; v++) {
        long long s = subtree_size[v];
        sum_dist += 2LL * s * (n - s);
    }

    long long even_cnt = 0, odd_cnt = 0;
    for (int i = 1; i <= n; i++) {
        if (depth[i] % 2 == 0) even_cnt++;
        else odd_cnt++;
    }
    long long count_odd = 2LL * even_cnt * odd_cnt;

    cout << (sum_dist + count_odd) / 2 << "\n";
    return 0;
}

Java

import java.util.*;
import java.io.*;

public class Main {
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
        int n = Integer.parseInt(br.readLine().trim());

        List<List<Integer>> adj = new ArrayList<>();
        for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());

        for (int i = 0; i < n - 1; i++) {
            StringTokenizer st = new StringTokenizer(br.readLine());
            int u = Integer.parseInt(st.nextToken());
            int v = Integer.parseInt(st.nextToken());
            adj.get(u).add(v);
            adj.get(v).add(u);
        }

        long[] depth = new long[n + 1];
        long[] subtreeSize = new long[n + 1];
        int[] parent = new int[n + 1];
        int[] order = new int[n];

        Arrays.fill(depth, -1);
        depth[1] = 0;
        Queue<Integer> q = new LinkedList<>();
        q.add(1);
        int idx = 0;
        while (!q.isEmpty()) {
            int u = q.poll();
            order[idx++] = u;
            subtreeSize[u] = 1;
            for (int v : adj.get(u)) {
                if (depth[v] == -1) {
                    depth[v] = depth[u] + 1;
                    parent[v] = u;
                    q.add(v);
                }
            }
        }

        for (int i = n - 1; i >= 0; i--) {
            int u = order[i];
            if (parent[u] != 0) {
                subtreeSize[parent[u]] += subtreeSize[u];
            }
        }

        long sumDist = 0;
        for (int v = 2; v <= n; v++) {
            long s = subtreeSize[v];
            sumDist += 2L * s * (n - s);
        }

        long evenCnt = 0, oddCnt = 0;
        for (int i = 1; i <= n; i++) {
            if (depth[i] % 2 == 0) evenCnt++;
            else oddCnt++;
        }
        long countOdd = 2L * evenCnt * oddCnt;

        System.out.println((sumDist + countOdd) / 2);
    }
}