题目链接
题目描述
给定一个由 个节点构成的树形网络。每个节点
转发信息会产生一个延迟,延迟时间等于该节点的度数(即与之直接相连的边的数量)。信息在两点
之间的传递时间,定义为
之间唯一路径上所有节点(包括
和
)的延迟时间之和。
你需要处理 次独立的查询,每次给定两个节点
和
,求出它们之间的信息传递最短时间。
解题思路
由于网络是树形结构,任意两点 之间的路径是唯一的,因此“最短时间”就是这条唯一路径上的权值(延迟)之和。节点
的权值等于其度数
deg[i]
。
对每次查询都进行一次 DFS/BFS 来寻找路径并求和,时间复杂度为 ,会超时。这是一个典型的树上路径问题,可以使用最近公共祖先(LCA) 结合倍增(Binary Lifting) 来高效解决。
算法的核心思想是:任意节点 之间的路径都可以被它们的最近公共祖先
分为两段:
和
。因此,路径
的总权值和为:
path_sum(u, v) = path_sum(u, l) + path_sum(v, l) - deg[l]
(因为 在两条路径中都被计算了一次,所以要减去一次它的权值)。
为了快速计算任意节点到其某个祖先的路径和,我们可以进行预处理:
-
预处理:
- 首先,根据输入的边构建邻接表,并计算所有节点的度数
deg[i]
。 - 从任意节点(如节点1)作为根,进行一次深度优先搜索(DFS)。在DFS中,为每个节点
计算:
- 深度
depth[u]
。 - 父节点
up[0][u]
。 - 从根到节点
的路径权值和
dist[u]
,其递推式为dist[u] = dist[parent] + deg[u]
。
- 深度
- 利用
up[0]
数组和递推关系up[p][u] = up[p-1][up[p-1][u]]
构建完整的倍增表,用于快速查询LCA。此过程复杂度为。
- 首先,根据输入的边构建邻接表,并计算所有节点的度数
-
查询: 对于每次查询
:
- 使用倍增表在
时间内找到
。
- 利用预计算的
dist
数组,在时间内计算路径总和。路径
的权值和为
dist[u] - dist[l] + deg[l]
。因此,总路径和的计算公式为:total_sum = (dist[u] - dist[l] + deg[l]) + (dist[v] - dist[l] + deg[l]) - deg[l]
化简后得到:total_sum = dist[u] + dist[v] - 2 * dist[l] + deg[l]
- 如果查询的
是同一个节点,该公式也成立,结果为
deg[u]
。
- 使用倍增表在
代码
#include <iostream>
#include <vector>
#include <numeric>
using namespace std;
const int MAXN = 100005;
const int LOGN = 18;
vector<int> adj[MAXN];
int deg[MAXN];
long long dist[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m;
void dfs(int u, int p, int d, long long current_dist) {
depth[u] = d;
up[0][u] = p;
dist[u] = current_dist + deg[u];
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1, dist[u]);
}
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int i = LOGN - 1; i >= 0; --i) {
if (depth[u] - (1 << i) >= depth[v]) {
u = up[i][u];
}
}
if (u == v) return u;
for (int i = LOGN - 1; i >= 0; --i) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
int main() {
cin >> n >> m;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
deg[u]++;
deg[v]++;
}
dfs(1, 0, 1, 0);
for (int i = 1; i < LOGN; ++i) {
for (int j = 1; j <= n; ++j) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
if (u == v) {
cout << deg[u] << "\n";
continue;
}
int l = lca(u, v);
long long result = dist[u] + dist[v] - 2 * dist[l] + deg[l];
cout << result << "\n";
}
return 0;
}
import java.util.ArrayList;
import java.util.Scanner;
public class Main {
static final int MAXN = 100005;
static final int LOGN = 18;
static ArrayList<Integer>[] adj = new ArrayList[MAXN];
static int[] deg = new int[MAXN];
static long[] dist = new long[MAXN];
static int[] depth = new int[MAXN];
static int[][] up = new int[LOGN][MAXN];
static int n, m;
static void dfs(int u, int p, int d, long currentDist) {
depth[u] = d;
up[0][u] = p;
dist[u] = currentDist; // dist[parent] is passed as currentDist
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1, dist[u] + deg[u]);
}
}
}
// A slightly different DFS to calculate dist correctly
static void dfs_dist(int u, int p) {
for(int v : adj[u]) {
if (v != p) {
dist[v] = dist[u] + deg[v];
dfs_dist(v, u);
}
}
}
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u; u = v; v = temp;
}
for (int i = LOGN - 1; i >= 0; i--) {
if (depth[u] - (1 << i) >= depth[v]) {
u = up[i][u];
}
}
if (u == v) return u;
for (int i = LOGN - 1; i >= 0; i--) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
adj[v].add(u);
deg[u]++;
deg[v]++;
}
dfs(1, 0, 1, 0); // DFS for depth and parents
dist[1] = deg[1];
dfs_dist(1, 0); // DFS to calculate path sums from root
for (int i = 1; i < LOGN; i++) {
for (int j = 1; j <= n; j++) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
if (u == v) {
System.out.println(deg[u]);
continue;
}
int l = lca(u, v);
long result = dist[u] + dist[v] - 2 * dist[l] + deg[l];
System.out.println(result);
}
}
}
import sys
sys.setrecursionlimit(100005)
def solve():
n, m = map(int, input().split())
MAXN = n + 1
LOGN = (n).bit_length()
adj = [[] for _ in range(MAXN)]
deg = [0] * MAXN
for _ in range(n - 1):
u, v = map(int, input().split())
adj[u].append(v)
adj[v].append(u)
deg[u] += 1
deg[v] += 1
depth = [0] * MAXN
up = [[0] * MAXN for _ in range(LOGN)]
dist = [0] * MAXN
# DFS to build parent table, depths and path sums from root
q = [(1, 0, 1, 0)] # u, p, d, current_dist
visited = {1}
head = 0
while head < len(q):
u, p, d, current_dist = q[head]
head += 1
depth[u] = d
up[0][u] = p
dist[u] = current_dist + deg[u]
for v in adj[u]:
if v != p:
visited.add(v)
q.append((v, u, d + 1, dist[u]))
for i in range(1, LOGN):
for j in range(1, n + 1):
up[i][j] = up[i - 1][up[i - 1][j]]
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for i in range(LOGN - 1, -1, -1):
if depth[u] - (1 << i) >= depth[v]:
u = up[i][u]
if u == v:
return u
for i in range(LOGN - 1, -1, -1):
if up[i][u] != up[i][v]:
u = up[i][u]
v = up[i][v]
return up[0][u]
results = []
for _ in range(m):
u, v = map(int, input().split())
if u == v:
results.append(str(deg[u]))
continue
l = lca(u, v)
result = dist[u] + dist[v] - 2 * dist[l] + deg[l]
results.append(str(result))
print("\n".join(results))
solve()
算法及复杂度
- 算法:最近公共祖先(LCA) + 倍增法
- 时间复杂度:
。预处理阶段需要
,后续的
次查询每次需要
。
- 空间复杂度:
,主要用于存储倍增表和其它辅助数组。