题目链接

PEEK60 旺仔哥哥的约会

题目描述

给定一个由 个节点构成的树形结构。有 次询问,每次询问给定四个节点 。你需要判断从 的最短路径,与从 的最短路径,是否在树上存在公共节点。

解题思路

在树中,任意两点之间的最短路径是唯一的。因此,问题就是判断 path(a, b)path(c, d) 这两条唯一路径是否相交。

对每次查询都遍历路径并检查交点,效率太低。我们可以利用最近公共祖先(LCA) 的性质来设计一个高效的算法。

核心判定条件

两条路径 path(a, b)path(c, d) 相交的充要条件是:其中一条路径的LCA,在另一条路径上

  • lca_ab = lca(a, b)lca_cd = lca(c, d)
  • 两条路径相交 (lca_abpath(c, d) 上) (lca_cdpath(a, b) 上)。

如何判断点在路径上

现在问题转化为如何高效判断一个点 是否在路径 path(u, v) 上。这同样可以利用LCA来判断。 一个节点 在路径 path(u, v) 上,等价于 的路径和 的路径是简单的(没有重合,只在 点交汇),并且这两段路径长度之和等于 的总路径长。

在树中,这可以简化为:

dist(u, p) + dist(p, v) == dist(u, v)

其中 dist(x, y) 表示节点 之间的距离(边数),可以通过深度和LCA计算:

dist(x, y) = depth[x] + depth[y] - 2 * depth[lca(x, y)]

虽然这个方法可行,但有一个更简洁的逻辑判断:

一个节点 在路径 path(u, v) lca(u, p) == plca(v, p) == lca(u, v),或者 lca(v, p) == plca(u, p) == lca(u, v)

算法流程

  1. 预处理
    • 对整棵树进行一次DFS,计算每个节点的深度 depth 和父节点 up[0]
    • 利用倍增法构建 up[p][u] 表,用于 查询LCA。
  2. 查询
    • 对于每次查询 (a, b, c, d),计算 lca_ab = lca(a, b)lca_cd = lca(c, d)
    • 检查 is_on_path(lca_ab, c, d)is_on_path(lca_cd, a, b) 是否至少有一个为真。

代码

#include <iostream>
#include <vector>
#include <algorithm>

using namespace std;

const int MAXN = 100005;
const int LOGN = 18;

vector<int> adj[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m;

void dfs(int u, int p, int d) {
    depth[u] = d;
    up[0][u] = p;
    for (int v : adj[u]) {
        if (v != p) {
            dfs(v, u, d + 1);
        }
    }
}

int lca(int u, int v) {
    if (depth[u] < depth[v]) swap(u, v);
    for (int i = LOGN - 1; i >= 0; --i) {
        if (depth[u] - (1 << i) >= depth[v]) {
            u = up[i][u];
        }
    }
    if (u == v) return u;
    for (int i = LOGN - 1; i >= 0; --i) {
        if (up[i][u] != up[i][v]) {
            u = up[i][u];
            v = up[i][v];
        }
    }
    return up[0][u];
}

bool is_on_path(int p, int u, int v) {
    int luv = lca(u, v);
    return (lca(u, p) == p && lca(p, v) == luv) || (lca(v, p) == p && lca(p, u) == luv);
}

int main() {
    cin >> n >> m;
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs(1, 0, 1);
    up[0][0] = 0;

    for (int i = 1; i < LOGN; ++i) {
        for (int j = 1; j <= n; ++j) {
            up[i][j] = up[i - 1][up[i - 1][j]];
        }
    }

    for (int i = 0; i < m; ++i) {
        int a, b, c, d;
        cin >> a >> b >> c >> d;
        int l_ab = lca(a, b);
        int l_cd = lca(c, d);
        if (is_on_path(l_ab, c, d) || is_on_path(l_cd, a, b)) {
            cout << "Yes\n";
        } else {
            cout << "No\n";
        }
    }

    return 0;
}
import java.util.ArrayList;
import java.util.Scanner;

public class Main {
    static final int MAXN = 100005;
    static final int LOGN = 18;
    static ArrayList<Integer>[] adj = new ArrayList[MAXN];
    static int[] depth = new int[MAXN];
    static int[][] up = new int[LOGN][MAXN];
    static int n, m;

    static void dfs(int u, int p, int d) {
        depth[u] = d;
        up[0][u] = p;
        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, d + 1);
            }
        }
    }

    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u; u = v; v = temp;
        }
        for (int i = LOGN - 1; i >= 0; i--) {
            if (depth[u] - (1 << i) >= depth[v]) {
                u = up[i][u];
            }
        }
        if (u == v) return u;
        for (int i = LOGN - 1; i >= 0; i--) {
            if (up[i][u] != up[i][v]) {
                u = up[i][u];
                v = up[i][v];
            }
        }
        return up[0][u];
    }

    static boolean isOnPath(int p, int u, int v) {
        int luv = lca(u, v);
        return (lca(u, p) == p && lca(p, v) == luv) || (lca(v, p) == p && lca(p, u) == luv);
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);
        n = sc.nextInt();
        m = sc.nextInt();

        for (int i = 1; i <= n; i++) adj[i] = new ArrayList<>();

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
        }

        dfs(1, 0, 1);
        up[0][0] = 0;

        for (int i = 1; i < LOGN; i++) {
            for (int j = 1; j <= n; j++) {
                up[i][j] = up[i - 1][up[i - 1][j]];
            }
        }

        for (int i = 0; i < m; i++) {
            int a = sc.nextInt();
            int b = sc.nextInt();
            int c = sc.nextInt();
            int d = sc.nextInt();
            int l_ab = lca(a, b);
            int l_cd = lca(c, d);
            if (isOnPath(l_ab, c, d) || isOnPath(l_cd, a, b)) {
                System.out.println("Yes");
            } else {
                System.out.println("No");
            }
        }
    }
}
import sys

sys.setrecursionlimit(100005)

def solve():
    n, m = map(int, input().split())
    
    MAXN = n + 1
    LOGN = (n).bit_length()
    
    adj = [[] for _ in range(MAXN)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
        
    depth = [0] * MAXN
    up = [[0] * MAXN for _ in range(LOGN)]

    def dfs(u, p, d):
        depth[u] = d
        up[0][u] = p
        for v in adj[u]:
            if v != p:
                dfs(v, u, d + 1)
    
    dfs(1, 0, 1)

    for i in range(1, LOGN):
        for j in range(1, n + 1):
            up[i][j] = up[i - 1][up[i - 1][j]]
            
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        
        for i in range(LOGN - 1, -1, -1):
            if depth[u] - (1 << i) >= depth[v]:
                u = up[i][u]
        
        if u == v:
            return u
            
        for i in range(LOGN - 1, -1, -1):
            if up[i][u] != up[i][v]:
                u = up[i][u]
                v = up[i][v]
                
        return up[0][u]

    def is_on_path(p, u, v):
        luv = lca(u, v)
        return (lca(u, p) == p and lca(p, v) == luv) or \
               (lca(v, p) == p and lca(p, u) == luv)

    results = []
    for _ in range(m):
        a, b, c, d = map(int, input().split())
        l_ab = lca(a, b)
        l_cd = lca(c, d)
        if is_on_path(l_ab, c, d) or is_on_path(l_cd, a, b):
            results.append("Yes")
        else:
            results.append("No")
    
    print("\n".join(results))

solve()

算法及复杂度

  • 算法:最近公共祖先(LCA) + 倍增法
  • 时间复杂度。预处理阶段需要 ,后续的 次查询每次需要
  • 空间复杂度,主要用于存储倍增表。