题目链接
题目描述
给定一个由 个节点构成的树形结构。有
次询问,每次询问给定四个节点
。你需要判断从
到
的最短路径,与从
到
的最短路径,是否在树上存在公共节点。
解题思路
在树中,任意两点之间的最短路径是唯一的。因此,问题就是判断 path(a, b)
和 path(c, d)
这两条唯一路径是否相交。
对每次查询都遍历路径并检查交点,效率太低。我们可以利用最近公共祖先(LCA) 的性质来设计一个高效的算法。
核心判定条件
两条路径 path(a, b)
和 path(c, d)
相交的充要条件是:其中一条路径的LCA,在另一条路径上。
- 令
lca_ab = lca(a, b)
,lca_cd = lca(c, d)
。 - 两条路径相交
(
lca_ab
在path(c, d)
上)(
lca_cd
在path(a, b)
上)。
如何判断点在路径上
现在问题转化为如何高效判断一个点 是否在路径
path(u, v)
上。这同样可以利用LCA来判断。
一个节点 在路径
path(u, v)
上,等价于 到
的路径和
到
的路径是简单的(没有重合,只在
点交汇),并且这两段路径长度之和等于
到
的总路径长。
在树中,这可以简化为:
dist(u, p) + dist(p, v) == dist(u, v)
其中 dist(x, y)
表示节点 之间的距离(边数),可以通过深度和LCA计算:
dist(x, y) = depth[x] + depth[y] - 2 * depth[lca(x, y)]
虽然这个方法可行,但有一个更简洁的逻辑判断:
一个节点 在路径
path(u, v)
上
lca(u, p) == p
且 lca(v, p) == lca(u, v)
,或者 lca(v, p) == p
且 lca(u, p) == lca(u, v)
。
算法流程
- 预处理:
- 对整棵树进行一次DFS,计算每个节点的深度
depth
和父节点up[0]
。 - 利用倍增法构建
up[p][u]
表,用于查询LCA。
- 对整棵树进行一次DFS,计算每个节点的深度
- 查询:
- 对于每次查询
(a, b, c, d)
,计算lca_ab = lca(a, b)
和lca_cd = lca(c, d)
。 - 检查
is_on_path(lca_ab, c, d)
和is_on_path(lca_cd, a, b)
是否至少有一个为真。
- 对于每次查询
代码
#include <iostream>
#include <vector>
#include <algorithm>
using namespace std;
const int MAXN = 100005;
const int LOGN = 18;
vector<int> adj[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m;
void dfs(int u, int p, int d) {
depth[u] = d;
up[0][u] = p;
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1);
}
}
}
int lca(int u, int v) {
if (depth[u] < depth[v]) swap(u, v);
for (int i = LOGN - 1; i >= 0; --i) {
if (depth[u] - (1 << i) >= depth[v]) {
u = up[i][u];
}
}
if (u == v) return u;
for (int i = LOGN - 1; i >= 0; --i) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
bool is_on_path(int p, int u, int v) {
int luv = lca(u, v);
return (lca(u, p) == p && lca(p, v) == luv) || (lca(v, p) == p && lca(p, u) == luv);
}
int main() {
cin >> n >> m;
for (int i = 0; i < n - 1; ++i) {
int u, v;
cin >> u >> v;
adj[u].push_back(v);
adj[v].push_back(u);
}
dfs(1, 0, 1);
up[0][0] = 0;
for (int i = 1; i < LOGN; ++i) {
for (int j = 1; j <= n; ++j) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; ++i) {
int a, b, c, d;
cin >> a >> b >> c >> d;
int l_ab = lca(a, b);
int l_cd = lca(c, d);
if (is_on_path(l_ab, c, d) || is_on_path(l_cd, a, b)) {
cout << "Yes\n";
} else {
cout << "No\n";
}
}
return 0;
}
import java.util.ArrayList;
import java.util.Scanner;
public class Main {
static final int MAXN = 100005;
static final int LOGN = 18;
static ArrayList<Integer>[] adj = new ArrayList[MAXN];
static int[] depth = new int[MAXN];
static int[][] up = new int[LOGN][MAXN];
static int n, m;
static void dfs(int u, int p, int d) {
depth[u] = d;
up[0][u] = p;
for (int v : adj[u]) {
if (v != p) {
dfs(v, u, d + 1);
}
}
}
static int lca(int u, int v) {
if (depth[u] < depth[v]) {
int temp = u; u = v; v = temp;
}
for (int i = LOGN - 1; i >= 0; i--) {
if (depth[u] - (1 << i) >= depth[v]) {
u = up[i][u];
}
}
if (u == v) return u;
for (int i = LOGN - 1; i >= 0; i--) {
if (up[i][u] != up[i][v]) {
u = up[i][u];
v = up[i][v];
}
}
return up[0][u];
}
static boolean isOnPath(int p, int u, int v) {
int luv = lca(u, v);
return (lca(u, p) == p && lca(p, v) == luv) || (lca(v, p) == p && lca(p, u) == luv);
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextInt();
m = sc.nextInt();
for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();
for (int i = 0; i < n - 1; i++) {
int u = sc.nextInt();
int v = sc.nextInt();
adj[u].add(v);
adj[v].add(u);
}
dfs(1, 0, 1);
up[0][0] = 0;
for (int i = 1; i < LOGN; i++) {
for (int j = 1; j <= n; j++) {
up[i][j] = up[i - 1][up[i - 1][j]];
}
}
for (int i = 0; i < m; i++) {
int a = sc.nextInt();
int b = sc.nextInt();
int c = sc.nextInt();
int d = sc.nextInt();
int l_ab = lca(a, b);
int l_cd = lca(c, d);
if (isOnPath(l_ab, c, d) || isOnPath(l_cd, a, b)) {
System.out.println("Yes");
} else {
System.out.println("No");
}
}
}
}
import sys
sys.setrecursionlimit(100005)
def solve():
n, m = map(int, input().split())
MAXN = n + 1
LOGN = (n).bit_length()
adj = [[] for _ in range(MAXN)]
for _ in range(n - 1):
u, v = map(int, input().split())
adj[u].append(v)
adj[v].append(u)
depth = [0] * MAXN
up = [[0] * MAXN for _ in range(LOGN)]
def dfs(u, p, d):
depth[u] = d
up[0][u] = p
for v in adj[u]:
if v != p:
dfs(v, u, d + 1)
dfs(1, 0, 1)
for i in range(1, LOGN):
for j in range(1, n + 1):
up[i][j] = up[i - 1][up[i - 1][j]]
def lca(u, v):
if depth[u] < depth[v]:
u, v = v, u
for i in range(LOGN - 1, -1, -1):
if depth[u] - (1 << i) >= depth[v]:
u = up[i][u]
if u == v:
return u
for i in range(LOGN - 1, -1, -1):
if up[i][u] != up[i][v]:
u = up[i][u]
v = up[i][v]
return up[0][u]
def is_on_path(p, u, v):
luv = lca(u, v)
return (lca(u, p) == p and lca(p, v) == luv) or \
(lca(v, p) == p and lca(p, u) == luv)
results = []
for _ in range(m):
a, b, c, d = map(int, input().split())
l_ab = lca(a, b)
l_cd = lca(c, d)
if is_on_path(l_ab, c, d) or is_on_path(l_cd, a, b):
results.append("Yes")
else:
results.append("No")
print("\n".join(results))
solve()
算法及复杂度
- 算法:最近公共祖先(LCA) + 倍增法
- 时间复杂度:
。预处理阶段需要
,后续的
次查询每次需要
。
- 空间复杂度:
,主要用于存储倍增表。