题目链接

PEEK58 【模板】最近公共祖先(LCA)

题目描述

给定一棵由 个节点组成的、以 为根的多叉树。有 次询问,每次给定两个节点 ,你需要求出这两个节点的最近公共祖先(LCA)。

最近公共祖先(LCA):指在树中离两个节点最近的、且同时是这两个节点的祖先的节点。

解题思路

本题是最近公共祖先(LCA)的模板题。由于节点和查询数量都很大,对于每次查询,从一个节点暴力向上走到根,并记录路径,再从另一个节点向上走直到碰到已记录路径的节点,这种朴素算法的时间复杂度为 ,会超时。

解决此问题的标准高效算法是倍增法(Binary Lifting)。该算法分为预处理查询两个阶段。

  1. 预处理

    • 深度优先搜索(DFS):首先,我们从根节点 开始对整棵树进行一次 DFS。在遍历过程中,我们可以确定每个节点的深度 depth[u] 和它的直接父节点 up[0][u](即 级祖先)。
    • 构建倍增表:我们创建一个二维数组 up[p][u],用于存储节点 的第 个祖先。根据 up[0][u],我们可以用动态规划的思想来填充整个表。递推关系是: 的第 个祖先,就是 的第 个祖先的第 个祖先。即 up[p][u] = up[p-1][up[p-1][u]]。 预处理阶段的总时间复杂度为
  2. 查询

    对于每次查询 ,我们利用预处理好的倍增表来快速定位LCA。

    • 将节点提到同一深度:首先,比较 的深度,将深度较大的节点(假设是 )向上提升,直到它的深度和 相同。这个提升的过程可以用倍增来加速,总能在 时间内完成。
    • 判断是否重合:如果此时 已经是同一个节点,那么它就是LCA,查询结束。
    • 同步向上跳跃:如果 不相同,我们就让它们同步地、一步一步地向上跳。为了效率,我们从大到小尝试跳跃的步长()。只要跳跃后两者没有相遇(即 up[p][u] != up[p][v]),我们就执行这次跳跃。这样可以保证我们最大限度地接近LCA,而又不会越过它。
    • 找到LCA:当这个跳跃过程结束后, 所在的节点就是LCA的两个直接子节点。因此,它们的父节点 up[0][u](或 up[0][v])就是所求的LCA。

    每次查询的时间复杂度为

代码

#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>

using namespace std;

const int MAXN = 500005;
const int LOGN = 20;

vector<int> adj[MAXN];
int depth[MAXN];
int up[LOGN][MAXN];
int n, m, s;

void dfs(int u, int p, int d) {
    depth[u] = d;
    up[0][u] = p;
    for (int v : adj[u]) {
        if (v != p) {
            dfs(v, u, d + 1);
        }
    }
}

int lca(int u, int v) {
    if (depth[u] < depth[v]) {
        swap(u, v);
    }

    int diff = depth[u] - depth[v];
    for (int i = LOGN - 1; i >= 0; --i) {
        if ((diff >> i) & 1) {
            u = up[i][u];
        }
    }

    if (u == v) {
        return u;
    }

    for (int i = LOGN - 1; i >= 0; --i) {
        if (up[i][u] != up[i][v]) {
            u = up[i][u];
            v = up[i][v];
        }
    }

    return up[0][u];
}

int main() {
    cin >> n >> m >> s;
    for (int i = 0; i < n - 1; ++i) {
        int u, v;
        cin >> u >> v;
        adj[u].push_back(v);
        adj[v].push_back(u);
    }

    dfs(s, 0, 1);

    for (int i = 1; i < LOGN; ++i) {
        for (int j = 1; j <= n; ++j) {
            up[i][j] = up[i - 1][up[i - 1][j]];
        }
    }

    for (int i = 0; i < m; ++i) {
        int u, v;
        cin >> u >> v;
        cout << lca(u, v) << "\n";
    }

    return 0;
}
import java.util.ArrayList;
import java.util.Scanner;

public class Main {
    static final int MAXN = 500005;
    static final int LOGN = 20;
    static ArrayList<Integer>[] adj = new ArrayList[MAXN];
    static int[] depth = new int[MAXN];
    static int[][] up = new int[LOGN][MAXN];
    static int n, m, s;

    static void dfs(int u, int p, int d) {
        depth[u] = d;
        up[0][u] = p;
        for (int v : adj[u]) {
            if (v != p) {
                dfs(v, u, d + 1);
            }
        }
    }

    static int lca(int u, int v) {
        if (depth[u] < depth[v]) {
            int temp = u;
            u = v;
            v = temp;
        }

        int diff = depth[u] - depth[v];
        for (int i = LOGN - 1; i >= 0; i--) {
            if (((diff >> i) & 1) == 1) {
                u = up[i][u];
            }
        }

        if (u == v) {
            return u;
        }

        for (int i = LOGN - 1; i >= 0; i--) {
            if (up[i][u] != up[i][v]) {
                u = up[i][u];
                v = up[i][v];
            }
        }

        return up[0][u];
    }

    public static void main(String[] args) {
        Scanner sc = new Scanner(System.in);

        n = sc.nextInt();
        m = sc.nextInt();
        s = sc.nextInt();

        for (int i = 1; i <= n; i++) {
            adj[i] = new ArrayList<>();
        }

        for (int i = 0; i < n - 1; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            adj[u].add(v);
            adj[v].add(u);
        }

        dfs(s, 0, 1);

        for (int i = 1; i < LOGN; i++) {
            for (int j = 1; j <= n; j++) {
                up[i][j] = up[i - 1][up[i - 1][j]];
            }
        }

        for (int i = 0; i < m; i++) {
            int u = sc.nextInt();
            int v = sc.nextInt();
            System.out.println(lca(u, v));
        }
    }
}
import sys

# 增大递归深度限制以处理深度较大的树
sys.setrecursionlimit(500005)

def solve():
    n, m, s = map(int, input().split())
    
    MAXN = n + 1
    LOGN = (n).bit_length()
    
    adj = [[] for _ in range(MAXN)]
    for _ in range(n - 1):
        u, v = map(int, input().split())
        adj[u].append(v)
        adj[v].append(u)
        
    depth = [0] * MAXN
    up = [[0] * MAXN for _ in range(LOGN)]

    def dfs(u, p, d):
        depth[u] = d
        up[0][u] = p
        for v in adj[u]:
            if v != p:
                dfs(v, u, d + 1)
    
    dfs(s, 0, 1)

    for i in range(1, LOGN):
        for j in range(1, n + 1):
            up[i][j] = up[i - 1][up[i - 1][j]]
            
    def lca(u, v):
        if depth[u] < depth[v]:
            u, v = v, u
        
        diff = depth[u] - depth[v]
        for i in range(LOGN - 1, -1, -1):
            if (diff >> i) & 1:
                u = up[i][u]
        
        if u == v:
            return u
            
        for i in range(LOGN - 1, -1, -1):
            if up[i][u] != up[i][v]:
                u = up[i][u]
                v = up[i][v]
                
        return up[0][u]

    results = []
    for _ in range(m):
        u, v = map(int, input().split())
        results.append(str(lca(u, v)))
    
    print("\n".join(results))

solve()

算法及复杂度

  • 算法:倍增法 (Binary Lifting)
  • 时间复杂度。预处理阶段需要 ,后续的 次查询每次需要
  • 空间复杂度,主要用于存储倍增表 up