0-based 版本 (0-based indexing version)
#include <iostream>
#include <vector>
#include <functional>
// Binary-lifting LCA, 0-based internally.
// Input : n m rt, then n-1 undirected edges "x y" (1-based labels),
//         then m queries "u v" (1-based labels).
// Output: for each query, the 1-based label of LCA(u, v), one per line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int n, m, rt;
    std::cin >> n >> m >> rt;
    rt--;  // convert root label to 0-based

    std::vector<std::vector<int>> adj(n);
    for (int i = 0; i < n - 1; i++) {
        int x, y;
        std::cin >> x >> y;
        x--, y--;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // Number of lifting levels: smallest LOG with 2^LOG >= n. Jumps of up to
    // 2^LOG - 1 >= n - 1 steps (the largest possible depth difference) are
    // then representable, instead of relying on a hard-coded 20 levels.
    int LOG = 1;
    while ((1LL << LOG) < n) {
        LOG++;
    }

    // f[i][u] = 2^i-th ancestor of u. The root is its own parent, so jumps
    // that would overshoot saturate at the root. -1 marks nodes never visited
    // (only possible if the input is not a connected tree).
    std::vector f(LOG, std::vector<int>(n, -1));
    std::vector<int> par(n, -1), dep(n);
    dep[rt] = 1;
    par[rt] = rt;

    // Iterative traversal with an explicit stack. A recursive DFS would recurse
    // n levels deep on a path-shaped tree and can overflow the call stack for
    // large n. A node is pushed only while its parent is being processed, so
    // every ancestor's row of f is already filled when the node is popped.
    std::vector<int> stk{rt};
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        f[0][u] = par[u];
        for (int i = 1; i < LOG; i++) {
            f[i][u] = f[i - 1][f[i - 1][u]];
        }
        for (const int v : adj[u]) {
            if (v == par[u]) {
                continue;
            }
            par[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }

    auto lca = [&](int u, int v) {
        if (dep[u] < dep[v]) {
            std::swap(u, v);
        }
        // Lift u up to v's depth, largest jumps first.
        for (int i = LOG - 1; i >= 0; i--) {
            if (~f[i][u] and dep[f[i][u]] >= dep[v]) {
                u = f[i][u];
            }
        }
        if (u == v) {
            return u;
        }
        // Lift both while their ancestors differ; they end as children of the LCA.
        for (int i = LOG - 1; i >= 0; i--) {
            if (~f[i][u] and ~f[i][v] and f[i][u] != f[i][v]) {
                u = f[i][u];
                v = f[i][v];
            }
        }
        return f[0][u];
    };

    for (int mi = 0; mi < m; mi++) {
        int u, v;
        std::cin >> u >> v;
        u--, v--;
        std::cout << lca(u, v) + 1 << "\n";
    }
}
1-based 版本 (1-based indexing version)
#include <iostream>
#include <vector>
#include <functional>
// Binary-lifting LCA, 1-based throughout.
// Input : n m rt, then n-1 undirected edges "x y", then m queries "u v".
// Output: for each query, the label of LCA(u, v), one per line.
int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    std::cout.tie(nullptr);

    int n, m, rt;
    std::cin >> n >> m >> rt;

    std::vector<std::vector<int>> adj(n + 1);
    for (int i = 1; i < n; i++) {
        int x, y;
        std::cin >> x >> y;
        adj[x].push_back(y);
        adj[y].push_back(x);
    }

    // Number of lifting levels: smallest LOG with 2^LOG >= n, so jumps of up
    // to 2^LOG - 1 >= n - 1 steps (the largest possible depth difference) are
    // representable, instead of relying on a hard-coded 20 levels.
    int LOG = 1;
    while ((1LL << LOG) < n) {
        LOG++;
    }

    // f[i][u] = 2^i-th ancestor of u. Node 0 is a sentinel: the root's parent
    // is 0, f[i][0] stays 0 and dep[0] stays 0. Jumps past the root therefore
    // land on 0, which is never deep enough to be taken in the first lifting
    // loop, and compares equal on both sides in the second.
    std::vector f(LOG, std::vector<int>(n + 1));
    std::vector<int> par(n + 1), dep(n + 1);
    dep[rt] = 1;

    // Iterative traversal with an explicit stack. A recursive DFS would recurse
    // n levels deep on a path-shaped tree and can overflow the call stack for
    // large n. A node is pushed only while its parent is being processed, so
    // every ancestor's row of f is already filled when the node is popped.
    std::vector<int> stk{rt};
    while (!stk.empty()) {
        const int u = stk.back();
        stk.pop_back();
        f[0][u] = par[u];
        for (int i = 1; i < LOG; i++) {
            f[i][u] = f[i - 1][f[i - 1][u]];
        }
        for (const int v : adj[u]) {
            if (v == par[u]) {
                continue;
            }
            par[v] = u;
            dep[v] = dep[u] + 1;
            stk.push_back(v);
        }
    }

    auto lca = [&](int u, int v) {
        if (dep[u] < dep[v]) {
            std::swap(u, v);
        }
        // Lift u up to v's depth, largest jumps first.
        for (int i = LOG - 1; i >= 0; i--) {
            if (dep[f[i][u]] >= dep[v]) {
                u = f[i][u];
            }
        }
        if (u == v) {
            return u;
        }
        // Lift both while their ancestors differ; they end as children of the LCA.
        for (int i = LOG - 1; i >= 0; i--) {
            if (f[i][u] != f[i][v]) {
                u = f[i][u];
                v = f[i][v];
            }
        }
        return f[0][u];
    };

    for (int mi = 0; mi < m; mi++) {
        int u, v;
        std::cin >> u >> v;
        std::cout << lca(u, v) << "\n";
    }
}

京公网安备 11010502036488号