LCA问题
求树上的最近公共祖先
1 倍增法
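倍增法,算法正如它的名字一样,比较好理解: 预处理出 parent[k][u], 表示从 u 节点向上走 2^k 步到达的节点编号 (不存在则为 -1)。递推的时候利用 parent[k+1][u] = parent[k][parent[k][u]], 即先向上走 2^k 步, 再向上走 2^k 步。查询时先把较深的节点按两者深度差的二进制位向上跳, 跳到与另一个节点同一深度; 若此时两者相同, 它就是答案。否则从大到小枚举 k, 只要两者的 2^k 级祖先不同就一起向上跳, 最后两者的父节点即为 LCA。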
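如果需要求两点间的距离, 可以同时维护 dis[k][u], 表示从 u 向上走 2^k 步经过的边权之和, 递推为 dis[k+1][u] = dis[k][u] + dis[k][parent[k][u]]。在向上跳的时候把对应的 dis 累加起来即可 (见下文 HDU2586)。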
POJ 1330
// POJ 1330
// Node capacity: 10^4 plus a little slack.
const int maxn = 10000 + 100;
// Number of binary-lifting levels. parent[k] exists for k in [0, maxlogv),
// i.e. jumps of 2^0 .. 2^13 steps, so any depth difference below 2^14 = 16384
// (which covers every depth possible with maxn nodes) can be decomposed.
const int maxlogv = 14;
// Adjacency lists indexed by 0-based node id.
vector<int> G[maxn];
// Root node id that init() starts its dfs from.
int root;
// parent[k][v]: node reached by climbing 2^k steps up from v, or -1 if none.
int parent[maxlogv][maxn];
// depth[v]: number of edges between the root and v.
int depth[maxn];
// Recursive depth-first walk from v that records, for every node reached,
// its direct parent in parent[0][.] and its depth in depth[.].
//   v - current node (0-based)
//   p - v's parent, or -1 when v is the root; skipped in case the adjacency
//       list contains the edge back to the parent
//   d - depth of v
void dfs(int v,int p,int d){
    depth[v] = d;
    parent[0][v] = p;
    const vector<int> &children = G[v];
    for(size_t j = 0; j < children.size(); ++j){
        const int next = children[j];
        if(next == p) continue;
        dfs(next, v, d + 1);
    }
}
// Builds the binary-lifting table for nodes 0..V-1.
// First runs dfs from `root` to fill parent[0][.] and depth[.]. Each higher
// level is then derived from the one below it: the 2^k-th ancestor is the
// 2^(k-1)-th ancestor of the 2^(k-1)-th ancestor. A missing ancestor (-1)
// propagates as -1 to every higher level.
void init(int V){
    dfs(root,-1,0);
    for(int k = 1; k < maxlogv; ++k){
        const int *half = parent[k-1];  // jumps of 2^(k-1) steps
        int *full = parent[k];          // jumps of 2^k steps being filled in
        for(int v = 0; v < V; ++v){
            const int mid = half[v];
            full[v] = (mid < 0) ? -1 : half[mid];
        }
    }
}
// Returns the lowest common ancestor of nodes u and v (0-based ids).
// Requires init() to have been run on the current tree.
int lca(int u,int v){
    // Make v the deeper of the two nodes.
    if(depth[v] < depth[u]) swap(u,v);
    // Lift v by the depth difference, one set bit (one 2^k jump) at a time.
    int diff = depth[v] - depth[u];
    for(int k = 0; k < maxlogv && diff != 0; ++k, diff >>= 1){
        if(diff & 1) v = parent[k][v];
    }
    if(u == v) return u;
    // Both nodes are now at the same depth. Take the largest jumps that keep
    // them strictly below their common ancestor.
    for(int k = maxlogv - 1; k >= 0; --k){
        const int pu = parent[k][u];
        const int pv = parent[k][v];
        if(pu != pv){
            u = pu;
            v = pv;
        }
    }
    // u and v are now distinct children of the LCA.
    return parent[0][u];
}
// OUT[v] is set when v appears as the child end of some edge. The root is the
// unique node that is never a child.
bool OUT[maxn];
// POJ 1330: reads T test cases. Each case is a rooted tree of n nodes given as
// n-1 directed "parent child" pairs (1-based), followed by one query pair.
// Prints the lowest common ancestor of the query pair (1-based).
int main(void)
{
    int T;
    scanf("%d",&T);
    while(T--){
        int n;
        // n must be read BEFORE it is used to clear the adjacency lists.
        // Clearing first reads an uninitialized n (undefined behaviour: it may
        // index G out of bounds, or leave stale edges from a previous case).
        scanf("%d",&n);
        rep(i,0,n) G[i].clear();
        me(OUT);  // NOTE(review): `me` is a macro defined outside this view; presumably zero-fills OUT.
        rep(i,1,n) {
            int u,v;
            scanf("%d %d",&u,&v);
            u--,v--;
            // Only the parent -> child direction is stored; dfs starts at the
            // root and never needs to walk upward.
            G[u].push_back(v);
            OUT[v] = 1;
        }
        // The root is the first node that never appeared as a child.
        rep(i,0,n) if(!OUT[i]){
            root = i;
            break;
        }
        init(n);
        int u,v;
        scanf("%d %d",&u,&v);
        u--,v--;
        printf("%d\n",lca(u,v)+1);
    }
    return 0;
}
HDU2586
HDU2586: 求树上两点间的最短路 (即路径长度), 使用 LCA + 倍增。
// Adjacency entry: (neighbor node id, edge weight).
typedef pair<int,int> P;
// Node capacity: 4*10^4 plus a little slack.
const int maxn = 40000 + 100;
// Binary-lifting levels: jumps of 2^0 .. 2^15 steps, enough to decompose any
// depth difference below 2^16 = 65536 (more than maxn).
const int maxlogv = 16;
// Undirected weighted adjacency lists indexed by 0-based node id.
vector<P> G[maxn];
// Root node id that init() starts its dfs from.
int root;
// parent[k][v]: node reached by climbing 2^k steps up from v, or -1 if none.
int parent[maxlogv][maxn];
// depth[v]: number of edges between the root and v.
int depth[maxn];
// dis[k][v]: total edge weight along the 2^k-step climb from v.
int dis[maxlogv][maxn];
// Recursive depth-first walk over the undirected weighted tree, starting at v.
// For every node reached it records:
//   parent[0][v] = p (-1 for the root),
//   depth[v]     = d,
//   dis[0][v]    = D, the weight of the edge (p, v) (the caller passes 0 for the root).
// The neighbor equal to p is skipped so the walk does not go back up.
void dfs(int v,int p,int d,int D){
    depth[v] = d;
    dis[0][v] = D;
    parent[0][v] = p;
    for(size_t j = 0; j < G[v].size(); ++j){
        const int next = G[v][j].first;
        const int weight = G[v][j].second;
        if(next == p) continue;
        dfs(next, v, d + 1, weight);
    }
}
// Builds the binary-lifting tables for nodes 0..V-1.
// dis is zeroed first. dfs from `root` then fills parent[0], depth and dis[0].
// Each higher level combines two half-length jumps:
//   parent[k][v] = parent[k-1][parent[k-1][v]]
//   dis[k][v]    = dis[k-1][v] + dis[k-1][parent[k-1][v]]
// When the ancestor does not exist, parent becomes -1 and dis keeps the 0
// written by memset.
void init(int V){
    memset(dis,0,sizeof(dis));
    dfs(root,-1,0,0);
    for(int k = 1; k < maxlogv; ++k){
        for(int v = 0; v < V; ++v){
            const int mid = parent[k-1][v];
            if(mid >= 0){
                dis[k][v] = dis[k-1][v] + dis[k-1][mid];
                parent[k][v] = parent[k-1][mid];
            } else {
                parent[k][v] = -1;
            }
        }
    }
}
// Despite its name, returns the LENGTH of the tree path between u and v
// (sum of edge weights), not the ancestor itself.
// Both nodes are lifted to their lowest common ancestor with the binary-lifting
// tables; the dis[][] weight of every jump taken is added up.
// Requires init() to have been run on the current tree.
int lca(int u,int v){
    if(depth[v] < depth[u]) swap(u,v);  // v is now the deeper node
    int total = 0;
    // Bring v up to u's depth, one set bit (one 2^k jump) at a time.
    int diff = depth[v] - depth[u];
    for(int k = 0; k < maxlogv && diff != 0; ++k, diff >>= 1){
        if(diff & 1){
            total += dis[k][v];
            v = parent[k][v];
        }
    }
    if(u == v) return total;
    // Same depth: take the largest jumps that keep both nodes below the LCA.
    for(int k = maxlogv - 1; k >= 0; --k){
        if(parent[k][u] == parent[k][v]) continue;
        total += dis[k][u];
        total += dis[k][v];
        u = parent[k][u];
        v = parent[k][v];
    }
    // u and v are now distinct children of the LCA; add the final edge of each.
    if(u != v) total += dis[0][u] + dis[0][v];
    return total;
}
// HDU 2586: reads T test cases. Each case gives n nodes and m queries, then
// n-1 undirected weighted edges "u v w" (1-based node ids), then m query pairs.
// Prints the path length between each queried pair.
int main(void)
{
    int T;
    scanf("%d",&T);
    while(T--){
        int n,m;
        scanf("%d %d",&n,&m);
        rep(i,0,n) G[i].clear();
        int u,v,w;
        rep(i,1,n) {
            scanf("%d %d %d",&u,&v,&w);
            --u;
            --v;
            // Undirected edge: store both directions with the same weight.
            G[u].push_back(P(v,w));
            G[v].push_back(P(u,w));
        }
        // Any node can serve as the root of an undirected tree; node 0 is used.
        root = 0;
        init(n);
        for(; m--; ){
            scanf("%d %d",&u,&v);
            --u;
            --v;
            printf("%d\n",lca(u,v));
        }
    }
    return 0;
}