树上倍增比起树链剖分代码短,容易查错,时空优,但是广度不如树链剖分.
具体实现
首先开一个n×logn的数组,比如fa[n][logn],其中fa[i][j]表示i节点的第2^j个父亲是谁。
那么就有: fa[i][j]=fa[fa[i][j-1]][j-1]
用文字叙述为:i的第2^j 个父亲 是i的第2^(j-1) 个父亲的第2^(j-1)个父亲。
下面是求i的第k个父亲的代码段:

int father(int i,int k)
{
   
    for(int x=0;x<=int(log2(k));x++)
        if((1<<x)&k)    //(1<<x)&k可以判断k的二进制表示中,第(x-1)位上是否为1
            i=fa[i][x];     //把i往上提
    return i;
}

我们可以通过一次dfs处理出fa数组:(dep[i]表示i的深度,这个可以一起处理出来,以后要用)

void dfs(int x)
{
   
    for(int i=1;i<=max0;i++)
        if(fa[x][i-1])   //在dfs(x)之前,x的父辈们的fa数组都已经计算完毕,所以可以用来计算x
            fa[x][i]=fa[fa[x][i-1]][i-1];
        else break;    //如果x已经没有第2^(i-1)个父亲了,那么也不会有更远的父亲,直接break
    for(/*每一个与x相连的节点i*/)
        if(i!=fa[x][0])     //如果i不是x的父亲就是x的儿子
        {
   
            fa[i][0]=x;       //记录儿子的第一个父亲是x
            dep[i]=dep[x]+1;      //处理深度
            dfs(i);
        }
}
 

树上倍增常用来求最近公共祖先

int LCA(int u,int v)
{
   
    if(dep[u]<dep[v])swap(u,v);  //我们默认u的深度一开始大于v,那么如果u的深度小就交换u和v
    int delta=dep[u]-dep[v];    //计算深度差
    for(int x=0;x<=max0;x++)    //此循环用于提到深度相同。
        if((1<<x)&delta)
            u=fa[u][x];
    if(u==v)return u;
    for(int x=max0;x>=0;x--)     //注意!此处循环必须是从大到小!因为我们应该越提越“精确”,
        if(fa[u][x]!=fa[v][x])   //如果从小到大的话就有可能无法提到正确位置,自己可以多想一下
        {
   
            u=fa[u][x];
            v=fa[v][x];
        }
    return fa[u][0];    //此时u、v的第一个父亲就是LCA。
}
 

倍增还可以有很多变化,这让倍增法可以优更多的变化。比如用data[i][j]记录i到他的第2^j 个父亲的路径长度,就可以边求LCA边求出两点距离,因为data[i][j]满足倍增的递推式:data[i][j]=data[i][j-1]+data[fa[i][j-1]][j-1]。或者用maxlen[i][j]记录i到第2^j个父亲的路径上最长边的边权,它满足maxlen[i][j]=max{maxlen[i][j-1],maxlen[fa[i][j-1]][j-1]},这样就可以快速求出两点路径上最长边的边权……

最后附上一道LCA的模板题代码

#include<cstdio>
#include<vector>
#include<cstring>
using namespace std;
const int maxn=10010;
vector<int>	v[maxn<<1];
int deep[maxn],fa[maxn][22];
int sum[maxn];
void dfs(int x){
   
	for(int i=1;i<=21;i++){
   
		if(fa[x][i-1])	fa[x][i]=fa[fa[x][i-1]][i-1];
		else break;
	}
	for(int i=0;i<v[x].size();i++){
   
		int t=v[x][i];
		if(t!=fa[x][0]){
   
			fa[t][0]=x;
			deep[t]=deep[x]+1;
			dfs(t);
		}
	}
}
int LCA(int u,int v){
   
	if(deep[u]<deep[v])	swap(u,v);
	int del=deep[u]-deep[v];
	for(int i=0;i<=21;i++)   
        if((1<<i)&del)
            u=fa[u][i];
	if(u==v)	return u;
	for(int i=21;i>=0;i--){
   
		if(fa[u][i]!=fa[v][i]){
   
			u=fa[u][i];
			v=fa[v][i];
		}
	}
	return fa[u][0];
}
int main(){
   
	int t;
	scanf("%d",&t);
	while(t--){
   
		memset(v,0,sizeof(v));
		memset(deep,0,sizeof(deep));
		memset(fa,0,sizeof(fa));
		memset(sum,0,sizeof(sum));
		int n;
		scanf("%d",&n);
		int x,y;
		for(int i=1;i<n;i++){
   
			scanf("%d%d",&x,&y);
			v[x].push_back(y);
			v[y].push_back(x);
			sum[y]++;
		}
		scanf("%d%d",&x,&y);
		for(int i=1;i<=n;i++)
			if(!sum[i])
				dfs(i);
		printf("%d\n",LCA(x,y));
	}
	return 0;
}