小红的加边距离
[题目链接](https://www.nowcoder.com/practice/b53afc9bbc214e4e8a465d85ed2563d9)
思路
关键观察:加边后的距离公式
题目允许添加边 ,条件是原树中存在节点
使得
和
都是原树的边(即
在原树中有公共邻居)。
加边策略:对于原树中任意两节点 ,若它们在原树中的距离
且存在公共邻居,则可以直接连边。而对于树中距离为 2 的两个节点,它们一定共享同一个中间节点作为公共邻居。因此,将所有可能的边都加完后,两点之间的新距离为:
$$
直觉上,新图中每一步可以"跳过"一个中间节点,因此相当于原距离除以 2 再向上取整。
样例验证:以示例树(5个点,边 1-2, 1-3, 2-4, 2-5)为例,节点 3 到节点 4 的原距离为 3,加边后距离为 ,与样例说明一致。
化简求和公式
设所有有序点对 (
)的答案为:
$$
利用恒等式 ,得:
$$
即:
$$
其中 是原树所有有序点对距离之和,
是距离为奇数的有序点对数量。
两个子问题的高效计算
子问题 1:树上所有点对距离之和
经典树形DP(贡献法):对树以节点 1 为根,每条边 将树分为两棵子树,大小分别为
和
。该边对所有点对距离之和的贡献(有序对)为:
$$
对所有非根节点 求和即得
,时间复杂度
。
子问题 2:奇数距离点对数
在树中,。由于
是偶数,所以:
$$
因此, 和
的距离为奇数,当且仅当它们的深度奇偶性不同(一个在偶数层,一个在奇数层)。
设深度为偶数的节点有 个,深度为奇数的节点有
个(
),则:
$$
(有序对,乘以 2)
样例演示
树:1-2, 1-3, 2-4, 2-5,以节点 1 为根:
- 深度:
(偶),
(奇),
(奇),
(偶),
(偶)
,
各边贡献:
- 边 (1,2):
,贡献
- 边 (1,3):
,贡献
- 边 (2,4):
,贡献
- 边 (2,5):
,贡献
✓
复杂度
- 时间:
(一次 BFS 计算深度和子树大小)
- 空间:
代码
C++
#include <iostream>
#include <vector>
#include <queue>
using namespace std;
int main() {
ios_base::sync_with_stdio(false);
cin.tie(NULL);
int n;
cin >> n;
vector<vector<int>> adj(n + 1);
for (int i = 0; i < n - 1; i++) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
vector<long long> subtree_size(n + 1, 0);
vector<long long> depth(n + 1, -1);
vector<int> order;
vector<int> parent(n + 1, 0);
depth[1] = 0;
queue<int> q;
q.push(1);
while (!q.empty()) {
int u = q.front(); q.pop();
order.push_back(u);
subtree_size[u] = 1;
for (int v : adj[u]) {
if (depth[v] == -1) {
depth[v] = depth[u] + 1;
parent[v] = u;
q.push(v);
}
}
}
for (int i = order.size() - 1; i >= 0; i--) {
int u = order[i];
if (parent[u] != 0) {
subtree_size[parent[u]] += subtree_size[u];
}
}
long long sum_dist = 0;
for (int v = 2; v <= n; v++) {
long long s = subtree_size[v];
sum_dist += 2LL * s * (n - s);
}
long long even_cnt = 0, odd_cnt = 0;
for (int i = 1; i <= n; i++) {
if (depth[i] % 2 == 0) even_cnt++;
else odd_cnt++;
}
long long count_odd = 2LL * even_cnt * odd_cnt;
cout << (sum_dist + count_odd) / 2 << "\n";
return 0;
}
Java
import java.util.*;
import java.io.*;
public class Main {
public static void main(String[] args) throws IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
int n = Integer.parseInt(br.readLine().trim());
List<List<Integer>> adj = new ArrayList<>();
for (int i = 0; i <= n; i++) adj.add(new ArrayList<>());
for (int i = 0; i < n - 1; i++) {
StringTokenizer st = new StringTokenizer(br.readLine());
int u = Integer.parseInt(st.nextToken());
int v = Integer.parseInt(st.nextToken());
adj.get(u).add(v);
adj.get(v).add(u);
}
long[] depth = new long[n + 1];
long[] subtreeSize = new long[n + 1];
int[] parent = new int[n + 1];
int[] order = new int[n];
Arrays.fill(depth, -1);
depth[1] = 0;
Queue<Integer> q = new LinkedList<>();
q.add(1);
int idx = 0;
while (!q.isEmpty()) {
int u = q.poll();
order[idx++] = u;
subtreeSize[u] = 1;
for (int v : adj.get(u)) {
if (depth[v] == -1) {
depth[v] = depth[u] + 1;
parent[v] = u;
q.add(v);
}
}
}
for (int i = n - 1; i >= 0; i--) {
int u = order[i];
if (parent[u] != 0) {
subtreeSize[parent[u]] += subtreeSize[u];
}
}
long sumDist = 0;
for (int v = 2; v <= n; v++) {
long s = subtreeSize[v];
sumDist += 2L * s * (n - s);
}
long evenCnt = 0, oddCnt = 0;
for (int i = 1; i <= n; i++) {
if (depth[i] % 2 == 0) evenCnt++;
else oddCnt++;
}
long countOdd = 2L * evenCnt * oddCnt;
System.out.println((sumDist + countOdd) / 2);
}
}

京公网安备 11010502036488号