题目链接
题目描述
给定一棵由 个节点组成的、以
为根的多叉树。有
次询问,每次给定两个节点
和
,你需要求出这两个节点的最近公共祖先(LCA)。
最近公共祖先(LCA):指在树中离两个节点最近的、且同时是这两个节点的祖先的节点。
解题思路
本题是最近公共祖先(LCA)的模板题。由于节点和查询数量都很大,对于每次查询,从一个节点暴力向上走到根,并记录路径,再从另一个节点向上走直到碰到已记录路径的节点,这种朴素算法的时间复杂度为 ,会超时。
解决此问题的标准高效算法是倍增法(Binary Lifting)。该算法分为预处理和查询两个阶段。
-
预处理
- 深度优先搜索(DFS):首先,我们从根节点
开始对整棵树进行一次 DFS。在遍历过程中,我们可以确定每个节点的深度
depth[u]
和它的直接父节点up[0][u]
(即的
级祖先)。
- 构建倍增表:我们创建一个二维数组
up[p][u]
,用于存储节点的第
个祖先。根据
up[0][u]
,我们可以用动态规划的思想来填充整个表。递推关系是:的第
个祖先,就是
的第
个祖先的第
个祖先。即
up[p][u] = up[p-1][up[p-1][u]]
。 预处理阶段的总时间复杂度为。
- 深度优先搜索(DFS):首先,我们从根节点
-
查询
对于每次查询
,我们利用预处理好的倍增表来快速定位LCA。
- 将节点提到同一深度:首先,比较
和
的深度,将深度较大的节点(假设是
)向上提升,直到它的深度和
相同。这个提升的过程可以用倍增来加速,总能在
时间内完成。
- 判断是否重合:如果此时
和
已经是同一个节点,那么它就是LCA,查询结束。
- 同步向上跳跃:如果
和
不相同,我们就让它们同步地、一步一步地向上跳。为了效率,我们从大到小尝试跳跃的步长(
)。只要跳跃后两者没有相遇(即
up[p][u] != up[p][v]
),我们就执行这次跳跃。这样可以保证我们最大限度地接近LCA,而又不会越过它。 - 找到LCA:当这个跳跃过程结束后,
和
所在的节点就是LCA的两个直接子节点。因此,它们的父节点
up[0][u]
(或up[0][v]
)就是所求的LCA。
每次查询的时间复杂度为
。
- 将节点提到同一深度:首先,比较
代码
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
using namespace std;
const int MAXN = 500005;
const int LOGN = 20;
vector<int> adj[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m, s;
void dfs(int u, int p, int d) {
depth[u] = d;
up[0][u] = p;
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1);
}
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) {
swap(u, v);
}
int diff = depth[u] - depth[v];
for (int i = LOGN - 1; i >= 0; --i) {
if ((diff >> i) & 1) {
u = up[i][u];
}
}
if (u == v) {
return u;
}
for (int i = LOGN - 1; i >= 0; --i) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
int main() {
cin >> n >> m >> s;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(s, 0, 1);
for (int i = 1; i < LOGN; ++i) {
for (int j = 1; j <= n; ++j) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; ++i) {
int u, v;
cin >> u >> v;
cout << lca(u, v) << "\n";
}
return 0;
}
import java.util.ArrayList;
import java.util.Scanner;
public class Main {
static final int MAXN = 500005;
static final int LOGN = 20;
static ArrayList<Integer>[] adj = new ArrayList[MAXN];
static int[] depth = new int[MAXN];
static int[][] up = new int[LOGN][MAXN];
static int n, m, s;
static void dfs(int u, int p, int d) {
depth[u] = d;
up[0][u] = p;
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1);
}
}
}
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u;
u = v;
v = temp;
}
int diff = depth[u] - depth[v];
for (int i = LOGN - 1; i >= 0; i--) {
if (((diff >> i) & 1) == 1) {
u = up[i][u];
}
}
if (u == v) {
return u;
}
for (int i = LOGN - 1; i >= 0; i--) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
s = sc.nextInt();
for (int i = 1; i <= n; i++) {
adj[i] = new ArrayList<>();
}
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
adj[v].add(u);
}
dfs(s, 0, 1);
for (int i = 1; i < LOGN; i++) {
for (int j = 1; j <= n; j++) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
System.out.println(lca(u, v));
}
}
}
import sys
# 增大递归深度限制以处理深度较大的树
sys.setrecursionlimit(500005)
def solve():
n, m, s = map(int, input().split())
MAXN = n + 1
LOGN = (n).bit_length()
adj = [[] for _ in range(MAXN)]
for _ in range(n - 1):
u, v = map(int, input().split())
adj[u].append(v)
adj[v].append(u)
depth = [0] * MAXN
up = [[0] * MAXN for _ in range(LOGN)]
def dfs(u, p, d):
depth[u] = d
up[0][u] = p
for v in adj[u]:
if v != p:
dfs(v, u, d + 1)
dfs(s, 0, 1)
for i in range(1, LOGN):
for j in range(1, n + 1):
up[i][j] = up[i - 1][up[i - 1][j]]
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
diff = depth[u] - depth[v]
for i in range(LOGN - 1, -1, -1):
if (diff >> i) & 1:
u = up[i][u]
if u == v:
return u
for i in range(LOGN - 1, -1, -1):
if up[i][u] != up[i][v]:
u = up[i][u]
v = up[i][v]
return up[0][u]
results = []
for _ in range(m):
u, v = map(int, input().split())
results.append(str(lca(u, v)))
print("\n".join(results))
solve()
算法及复杂度
- 算法:倍增法 (Binary Lifting)
- 时间复杂度:
。预处理阶段需要
,后续的
次查询每次需要
。
- 空间复杂度:
,主要用于存储倍增表
up
。