题解:BISHI124 【模板】最近公共祖先(LCA)

题目链接

【模板】最近公共祖先(LCA)

题目描述

给定以 $r$ 为根、共 $n$ 个节点的树,进行 $q$ 次询问,每次求两点的最近公共祖先(LCA)。

解题思路

二进制倍增(倍增跳父):

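  • 预处理每个点的深度 $\mathrm{dep}[v]$ 与第 $2^k$ 个祖先 $\mathrm{up}[k][v]$(越过根时记为哨兵 $0$),其中 $\mathrm{up}[k][v]=\mathrm{up}[k-1][\mathrm{up}[k-1][v]]$。
  • 回答时先将较深的点按深度差的二进制位上提到与另一点同一深度;若此时两点重合,它就是 LCA。否则从高位到低位,仅当两点的 $2^k$ 级祖先不同时才同时上跳,最后两点的父节点即为 LCA。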

代码

#include <bits/stdc++.h>
using namespace std;

// Reads "n q r", n-1 undirected edges and q query pairs from stdin and prints
// the lowest common ancestor of each pair (binary lifting), one per line.
int main(){
    ios::sync_with_stdio(false);
    cin.tie(nullptr);

    // Header: vertex count, query count, root. Empty input -> print nothing.
    int n, q, r;
    if (!(cin >> n >> q >> r)) return 0;

    // Undirected adjacency lists, vertices 1..n.
    vector<vector<int>> adj(n + 1);
    for (int e = 1; e < n; ++e) {
        int x, y;
        cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // levels = smallest L >= 1 with 2^L > n, so every depth difference fits in L bits.
    int levels = 1;
    while ((1 << levels) <= n) levels++;

    // depth[v]: distance from the root (-1 = not reached yet).
    // jump[k][v]: the 2^k-th ancestor of v; 0 is the sentinel above the root.
    vector<int> depth(n + 1, -1);
    vector<vector<int>> jump(levels, vector<int>(n + 1, 0));
    depth[r] = 0;
    vector<int> order{r};  // FIFO BFS frontier kept in a growing vector
    for (size_t head = 0; head < order.size(); ++head) {
        int u = order[head];
        for (int v : adj[u]) {
            if (depth[v] != -1) continue;
            depth[v] = depth[u] + 1;
            jump[0][v] = u;
            order.push_back(v);
        }
    }

    // Doubling: the 2^k-th ancestor is the 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor.
    for (int k = 1; k < levels; ++k) {
        const vector<int>& half = jump[k - 1];
        vector<int>& full = jump[k];
        for (int v = 1; v <= n; ++v) full[v] = half[half[v]];
    }

    // Move v up by `steps` edges, consuming one set bit of `steps` per level.
    auto climb = [&](int v, int steps) {
        for (int bit = 0; steps != 0; ++bit, steps >>= 1)
            if (steps & 1) v = jump[bit][v];
        return v;
    };

    auto lca = [&](int a, int b) {
        if (depth[a] < depth[b]) swap(a, b);
        a = climb(a, depth[a] - depth[b]);
        if (a == b) return a;
        // Jump both only while their ancestors differ; they stop just below the LCA.
        for (int k = levels - 1; k >= 0; --k) {
            if (jump[k][a] == jump[k][b]) continue;
            a = jump[k][a];
            b = jump[k][b];
        }
        return jump[0][a];
    };

    while (q-- != 0) {
        int a, b;
        cin >> a >> b;
        cout << lca(a, b) << '\n';
    }
    return 0;
}
import java.io.*;
import java.util.*;

public class Main {
    /**
     * Byte-buffered reader of whitespace-separated decimal integers.
     *
     * <p>{@link #nextInt()} throws {@link EOFException} once the input is exhausted, instead of
     * spinning forever on the {@code -1} end-of-stream marker.
     */
    static class FastScanner {
        private final InputStream in;
        private final byte[] buf = new byte[1 << 16];
        private int p = 0;
        private int l = 0;

        FastScanner(InputStream is) {
            in = is;
        }

        /**
         * Returns the next byte, or -1 at end of stream. Bytes are returned signed, so non-ASCII
         * bytes come back negative and are skipped like whitespace (0xFF, which never occurs in
         * UTF-8 text, is indistinguishable from end of stream).
         */
        private int read() throws IOException {
            if (p >= l) {
                l = in.read(buf);
                p = 0;
                if (l <= 0) {
                    l = 0;
                    return -1;
                }
            }
            return buf[p++];
        }

        /**
         * Reads the next (optionally negative) integer.
         *
         * @throws EOFException if the stream ends before another token starts
         */
        int nextInt() throws IOException {
            int c;
            do {
                c = read();
                if (c == -1) {
                    throw new EOFException("unexpected end of input");
                }
            } while (c <= ' ');
            int sign = 1;
            if (c == '-') {
                sign = -1;
                c = read();
            }
            int x = 0;
            while (c > ' ') {
                x = x * 10 + (c - '0');
                c = read();
            }
            return x * sign;
        }
    }

    /**
     * Reads {@code n q r}, the {@code n - 1} tree edges and {@code q} query pairs from stdin and
     * prints the lowest common ancestor of each pair, one per line. Input without a complete
     * header prints nothing.
     */
    public static void main(String[] args) throws Exception {
        FastScanner fs = new FastScanner(System.in);
        int n;
        int q;
        int r;
        try {
            n = fs.nextInt();
            q = fs.nextInt();
            r = fs.nextInt();
        } catch (EOFException ignored) {
            return; // no complete header: nothing to answer
        }

        // Adjacency as singly linked edge lists in flat int arrays: no raw ArrayList[],
        // no Integer boxing. head[u] is u's first edge slot, -1 when u has none.
        int slots = Math.max(0, 2 * (n - 1));
        int[] head = new int[n + 1];
        Arrays.fill(head, -1);
        int[] edgeNext = new int[slots];
        int[] edgeTo = new int[slots];
        for (int i = 0, e = 0; i < n - 1; i++) {
            int u = fs.nextInt();
            int v = fs.nextInt();
            edgeTo[e] = v;
            edgeNext[e] = head[u];
            head[u] = e++;
            edgeTo[e] = u;
            edgeNext[e] = head[v];
            head[v] = e++;
        }

        // Smallest log >= 1 with 2^log > n, so any depth difference fits in log bits.
        int log = 1;
        while ((1 << log) <= n) {
            ++log;
        }
        int[] dep = new int[n + 1];
        int[][] up = new int[log][n + 1];
        bfs(head, edgeNext, edgeTo, r, dep, up[0]);
        // up[k][v] = 2^k-th ancestor; up[k][0] stays 0, so jumps past the root stay at 0.
        for (int k = 1; k < log; k++) {
            int[] prev = up[k - 1];
            int[] cur = up[k];
            for (int v = 1; v <= n; v++) {
                cur[v] = prev[prev[v]];
            }
        }

        StringBuilder out = new StringBuilder();
        while (q-- > 0) {
            int a = fs.nextInt();
            int b = fs.nextInt();
            out.append(lca(a, b, dep, up)).append('\n');
        }
        System.out.print(out);
        System.out.flush();
    }

    /**
     * Breadth-first search from {@code root}. Fills {@code dep[v]} with the edge distance from
     * the root and {@code parent[v]} with v's parent. {@code parent[root]} is 0, the sentinel
     * above the root. Vertices that are not reached keep depth -1.
     */
    private static void bfs(
            int[] head, int[] edgeNext, int[] edgeTo, int root, int[] dep, int[] parent) {
        Arrays.fill(dep, -1);
        int[] queue = new int[dep.length]; // each vertex is enqueued at most once
        int qHead = 0;
        int qTail = 0;
        dep[root] = 0;
        parent[root] = 0;
        queue[qTail++] = root;
        while (qHead < qTail) {
            int u = queue[qHead++];
            for (int e = head[u]; e != -1; e = edgeNext[e]) {
                int v = edgeTo[e];
                if (dep[v] == -1) {
                    dep[v] = dep[u] + 1;
                    parent[v] = u;
                    queue[qTail++] = v;
                }
            }
        }
    }

    /** Returns the lowest common ancestor of {@code a} and {@code b} via binary lifting. */
    private static int lca(int a, int b, int[] dep, int[][] up) {
        if (dep[a] < dep[b]) {
            int t = a;
            a = b;
            b = t;
        }
        // Lift the deeper vertex by the depth difference, one set bit per level.
        for (int d = dep[a] - dep[b], k = 0; d > 0; d >>= 1, k++) {
            if ((d & 1) == 1) {
                a = up[k][a];
            }
        }
        if (a == b) {
            return a;
        }
        // Jump both while their ancestors differ; they end as children of the LCA.
        for (int k = up.length - 1; k >= 0; k--) {
            if (up[k][a] != up[k][b]) {
                a = up[k][a];
                b = up[k][b];
            }
        }
        return up[0][a];
    }
}
import sys
from collections import deque

def solve(data: bytes) -> str:
    """Answer lowest-common-ancestor queries on a rooted tree by binary lifting.

    ``data`` is the raw input: whitespace-separated integers ``n q r``, then
    ``n - 1`` undirected edges ``u v``, then ``q`` query pairs ``a b``.
    Returns the answers one per line, each newline-terminated, or ``""`` when
    the input does not contain the ``n q r`` header.
    """
    tokens = data.split()
    if len(tokens) < 3:
        return ""  # empty/truncated header: nothing to answer
    it = map(int, tokens)
    n, q, r = next(it), next(it), next(it)

    g = [[] for _ in range(n + 1)]
    for _ in range(n - 1):
        u, v = next(it), next(it)
        g[u].append(v)
        g[v].append(u)

    # Smallest log >= 1 with 2**log > n, so every depth difference fits in `log` bits.
    log = max(1, n.bit_length())

    # BFS from the root: depth and direct parent; parent 0 is the sentinel above the root.
    dep = [-1] * (n + 1)
    parent = [0] * (n + 1)
    dep[r] = 0
    dq = deque([r])
    while dq:
        u = dq.popleft()
        for v in g[u]:
            if dep[v] == -1:
                dep[v] = dep[u] + 1
                parent[v] = u
                dq.append(v)

    # up[k][v] = 2**k-th ancestor of v; index 0 maps to 0, so jumps past the root stay at 0.
    up = [parent]
    for _ in range(1, log):
        prev = up[-1]
        up.append([prev[p] for p in prev])

    out = []
    for _ in range(q):
        a, b = next(it), next(it)
        if dep[a] < dep[b]:
            a, b = b, a
        # Lift the deeper vertex by the depth difference, one set bit at a time.
        d, k = dep[a] - dep[b], 0
        while d:
            if d & 1:
                a = up[k][a]
            d >>= 1
            k += 1
        if a != b:
            # Jump both while their ancestors differ; they end as children of the LCA.
            for row in reversed(up):
                if row[a] != row[b]:
                    a, b = row[a], row[b]
            a = parent[a]
        out.append(f"{a}\n")
    return "".join(out)


def main() -> None:
    """Read the whole problem from stdin and write the answers to stdout."""
    sys.stdout.write(solve(sys.stdin.buffer.read()))


if __name__ == "__main__":
    main()

算法及复杂度

  • 算法:二进制倍增 LCA(预处理深度数组 $\mathrm{dep}$ 与倍增祖先表 $\mathrm{up}$)
  • 时间复杂度:预处理 $O(n\log n)$,每次查询 $O(\log n)$,总计 $O((n+q)\log n)$
  • 空间复杂度:$O(n\log n)$