一、基本概念:
给定一棵有根树,若节点z既是节点x的祖先,也是节点y的祖先,则称z是x,y的公共祖先。在x,y的所有公共祖先中,深度最大的一个称为x,y的最近公共祖先,记为LCA(x,y)。
LCA(x,y)是x到根的路径与y到根的路径的交汇点。它也是x与y之间的路径上深度最小的节点。求最近公共祖先的方法通常有五种。
在此介绍三种,另外两种还没学。
二、向上标记法:
从x向上走到根节点,并标记所有经过的节点。
从y向上走到根节点,当第一次遇到已标记的节点时,就找到了LCA(x,y)。
对于每个询问,向上标记法的时间复杂度最坏为O(n)。
三、树上倍增法:
树上倍增法是一个很重要的算法。除了LCA之外,它在很多问题中都有广泛应用。设f(x,k)表示x的2k辈祖先,即从x向根节点走2k步到达的节点。特别的,若该节点不存在,则令f(x,k)=0。f(x,0)就是x的父节点。除此之外,1≤k≤logn ,f(x,k)=f(f(x,k-1),k-1)。
这类似于一个动态规划的过程,阶段就是节点的深度。因此,我们可以对树进行广度优先遍历,按照层次顺序,在节点入队之前,计算它在F数组中相应的值。
以上是预处理部分,时间复杂度为O(nlogn),之后可以多次对不同的x,y计算LCA,每次询问的时间复杂度为O(logn)。
基于f数组计算LCA(x,y)分为以下几步:
(1)----设d(x)表示x的深度。不妨设d(x)≥d(y)。(否则可以交换x,y)。
(2)----用二进制拆分思想,把x向上调整到于y同一深度。
具体来说,就是依次尝试从x向上走k=2logn,……21,20步,检查到达的节点是否比y深。在每次检查中,若是,则令x=f(x,k)。
(3)----若此时x==y,说明已经找到了LCA,LCA就是y。
(4)----用二进制拆分思想,把x,y同时向上调整,并保持深度一致且二者不相汇。
具体来说,就是依次尝试把x,y同时向上走k=2logn,……21,20步,在每次尝试中,
若f(x,k)≠f(y,k)(即仍未相会),则令x=f(x,k),y=f(y,k)。
(5)----此时x,y必定只差一步就相会了,他们的父节点f(x,0)就是LCA。
以HDOJ2586 how far away为例:
时间复杂度为:O((n+m)logn)。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#define ll long long
using namespace std;
const int maxn=50010;
int f[maxn][20],d[maxn],dis[maxn];
int ver[maxn*2],nt[maxn*2],edge[maxn*2],head[maxn];
int T,n,m,tot,t;
queue<int>q;
void add(int x,int y,int z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
void bfs(void)
{
q.push(1);
d[1]=1;
while(q.size())
{
int x=q.front();
q.pop();
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(d[y]) continue;
d[y]=d[x]+1;
dis[y]=dis[x]+edge[i];
f[y][0]=x;
for(int j=1;j<=t;j++)
f[y][j]=f[f[y][j-1]][j-1];
q.push(y);
}
}
}
int lca(int x,int y)
{
if(d[x]>d[y]) swap(x,y);
for(int i=t;i>=0;i--)
{
if(d[f[y][i]]>=d[x]) y=f[y][i];
}
if(x==y) return x;
for(int i=t;i>=0;i--)
{
if(f[x][i]!=f[y][i]) x=f[x][i],y=f[y][i];
}
return f[x][0];
}
int main(void)
{
cin>>T;
while(T--)
{
cin>>n>>m;
t=(int)(log(n)/log(2))+1;
for(int i=1;i<=n;i++) dis[i]=head[i]=d[i]=0;
tot=0;
int x,y,z;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
bfs();
for(int i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
printf("%d\n",dis[x]+dis[y]-2*dis[lca(x,y)]);
}
}
return 0;
}
四、LCA的Tarjan算法:
Tarjan算法本质上是使用并查集对“向上标记法”的优化。它是一个离线算法,需要把m个询问一次性读入,统一计算,最后统一输出。时间复杂度O(n+m)。
在深度优先遍历的任意时刻,树中节点分为三类:
(1)----已经访问完毕并且回溯的节点。在这些节点上标记一个整数2。
(2)----已经开始递归,但尚未回溯的节点。这些节点就是当前正在访问的节点x以及x的祖先。在这些节点上标记一个整数1。
(3)----尚未访问的节点。这些节点没有标记。
对于正在访问的节点x,它到根节点的路径已经标记为1。若y是已经访问完毕并且回溯的节点,则LCA(x,y)就是从y向上走到根,第一个遇到的标记为1的节点。
可以利用并查集进行优化,当一个节点获得整数2的标记时,把它所在的集合合并到它的父节点所在的集合中(合并时它的父节点标记一定为1,且单独构成一个集合)。
这相当于每个完成回溯的节点都有一个指针指向它的父节点,只需查询y所在集合的代表元素(并查集的get操作),就等价于从y向上一直走到一个开始递归但尚未回溯的节点(具有标记1),即LCA(x,y)。
此时扫描与x相关的所有询问,若询问当中的另一个点y的标记为2,就知道了该询问的回答应该是y在并查集中的代表元素(get(y)函数的结果)。
同样以how far away为例:
时间复杂度为O(n+m)。
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<string>
#include<cmath>
#include<queue>
#include<vector>
#define ll long long
using namespace std;
const int maxn=50010;
int ver[maxn*2],nt[maxn*2],edge[maxn*2],head[maxn];
int fa[maxn],d[maxn],v[maxn],lca[maxn],ans[maxn];
vector<int>query[maxn],query_id[maxn];
int T,n,m,tot,t;
void add(int x,int y,int z)
{
ver[++tot]=y,edge[tot]=z;
nt[tot]=head[x],head[x]=tot;
}
void add_query(int x,int y,int id)
{
query[x].push_back(y),query_id[x].push_back(id);
query[y].push_back(x),query_id[y].push_back(id);
}
int _get(int x)
{
if(x!=fa[x])
fa[x]=_get(fa[x]);
return fa[x];
}
void tarjan(int x)
{
v[x]=1;
for(int i=head[x];i;i=nt[i])
{
int y=ver[i];
if(v[y]) continue;
d[y]=d[x]+edge[i];
tarjan(y);
fa[y]=x;
}
for(int i=0;i<query[x].size();i++)
{
int y=query[x][i],id=query_id[x][i];
if(v[y]==2)
{
int lca=_get(y);
ans[id]=d[x]+d[y]-2*d[lca];
//ans[id]=min(ans[id],d[x]+d[y]-2*d[lca]);
}
}
v[x]=2;
}
int main(void)
{
cin>>T;
while(T--)
{
cin>>n>>m;
for(int i=1;i<=n;i++)
{
head[i]=0;fa[i]=i;v[i]=0;
query[i].clear();
query_id[i].clear();
}
tot=0;
int x,y,z;
for(int i=1;i<n;i++)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z);
add(y,x,z);
}
for(int i=1;i<=m;i++)
{
scanf("%d%d",&x,&y);
if(x==y) ans[i]=0;
else
{
add_query(x,y,i);
ans[i]=1<<30;
}
}
tarjan(1);
for(int i=1;i<=m;i++)
printf("%d\n",ans[i]);
}
return 0;
}