树链剖分


发现最近几天可以出专题了。。。近几天搞板子题真的是逼着我写一些东西。。。那么我们来搞一搞树链剖分

原理

树链剖分,实际上就是一种把树结构映射到一颗线段树结构上的算法,常用于搞各种树上的两点路径查询及修改的问题,但树的形态不能改变,否则要改用LCT,然而我还不会hehe

我们记录如下的一些东西:

  • top 数组,用于记录每个点所在的链的顶端的节点

  • dfs 数组,用于记录每个点在 DFS 序列中的位置

  • size 数组,记录每个节点的子树大小

  • son 数组,记录每个点的重儿子

  • fa 数组,记录每个点的父亲

  • dep 数组,记录每个节点的深度

  • i2x 数组,就我个人而言,这是我的习惯,是 dfs 的一个反映射

    现在说一下最常用的树剖方法:轻重树链剖分法,每次我们找一个节点的儿子中子树大小最大的那一个,然后把它与原节点归到同一个链子上,其余的节点再分别作为其他链子的顶端接着搞

其实挺简单的,就是代码量大,调试困难,容易写错,浪费时间罢了

代码

恩,下面这个是一个带两点路径查询以及两点路径修改的树剖,查询时查找SUM和MAX

#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<vector>
#define maxn 100005
#define time t
using namespace std;
int t=0,n,m,qx,qy,qd,dep[maxn],size[maxn],fa[maxn],id[maxn],id2x[maxn],son[maxn],top[maxn],line[maxn],sum[maxn*2],maxnum[maxn*2];
vector<int> geo[maxn];
void DFS1(int u){
    dep[u]=dep[fa[u]]+1;
    size[u]=1;
    son[u]=0;
    for(int i=0;i<geo[u].size();i++){
        int op=geo[u][i];
        if(op==fa[u])continue;
        fa[op]=u;
        DFS1(op);
        size[u]+=size[op];
        if(size[son[u]]<size[op]){
            son[u]=op;
        }
    }
}
void DFS2(int x,int tp){
    top[x]=tp;
    id[x]=++time;
    id2x[time]=x;
    if(son[x])DFS2(son[x],tp);
    for(int i=0;i<geo[x].size();i++){
        int op=geo[x][i];
        if(op==fa[x]||op==son[x])continue;
        DFS2(op,op);
    }
}
void build(int l,int r,int o){
    if(l==r){
        sum[o]=line[id2x[l]],maxnum[o]=line[id2x[l]];
        return;
    }
    int mid=((r-l)>>1)+l;
    build(l,mid,o<<1);
    build(mid+1,r,(o<<1)+1);
    sum[o]=sum[o<<1]+sum[(o<<1)+1];
    maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
    return;
}
void UPDATE(int l,int r,int o){
    if(qx<=l&&r<=qy){
        sum[o]=qd,maxnum[o]=qd;
        return;
    }
    int mid=((r-l)>>1)+l;
    if(qx<=mid)UPDATE(l,mid,o<<1);
    if(mid<qy)UPDATE(mid+1,r,(o<<1)+1);
    sum[o]=sum[o<<1]+sum[(o<<1)+1];
    maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
    return;
}
int get_sum(int l,int r,int o){
    if(qx<=l&&r<=qy){
        return sum[o];
    }
    int mid=((r-l)>>1)+l,ans=0;
    if(qx<=mid){
        ans+=get_sum(l,mid,o<<1);
    }
    if(qy>mid){
        ans+=get_sum(mid+1,r,(o<<1)+1);
    }
    return ans;
}
int get_max(int l,int r,int o){
    if(qx<=l&&r<=qy){
        return maxnum[o];
    }
    int mid=((r-l)>>1)+l,ans=0;
    if(qx<=mid){
        ans=max(ans,get_max(l,mid,o<<1));
    }
    if(mid<qy)
        ans=max(ans,get_max(mid+1,r,(o<<1)+1));
    return ans;
}
void init(){
    DFS1(1);
    DFS2(1,1);
    build(1,n,1);
}
void MAX(int x,int y){
    int f1=top[x];
    int f2=top[y];
    int ans=0;
    while(f1!=f2){
        if(dep[f1]>dep[f2]){
            qx=id[f1],qy=id[x];
            ans=max(ans,get_max(1,n,1));
            x=fa[f1];
        }
        else{
            qx=id[f2],qy=id[y];
            ans=max(ans,get_max(1,n,1));
            y=fa[f2];
        }
        f1=top[x];
        f2=top[y];
    }
    if(dep[x]>dep[y])swap(x,y);
    qx=id[x],qy=id[y];
    ans=max(ans,get_max(1,n,1));
    printf("%d\n",ans);
}
void SUM(int x,int y){
    int f1=top[x];
    int f2=top[y];
    int ans=0;
    while(f1!=f2){
        if(dep[f1]>dep[f2]){
            qx=id[f1],qy=id[x];
            ans+=get_sum(1,n,1);
            x=fa[f1];
        }
        else{
            qx=id[f2],qy=id[y];
            ans+=get_sum(1,n,1);
            y=fa[f2];
        }
        f1=top[x];
        f2=top[y];
    }
    if(dep[x]>dep[y])swap(x,y);
    qx=id[x],qy=id[y];
    ans+=get_sum(1,n,1);
    printf("%d\n",ans);
}
void work(int op,int x,int y){
    if(op==0){
        qx=id[x],qy=id[x],qd=y;
        UPDATE(1,n,1);
    }
    else if(op==1){
        MAX(x,y);
    }
    else{
        SUM(x,y);
    }
    return;
}
int main(){
    /*freopen("input.txt","r",stdin); freopen("output.txt","w",stdout);*/
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)
        scanf("%d",&line[i]);
    int x,y;
    char ch[10];
    for(int i=0;i<n-1;i++){
        scanf("%d%d",&x,&y);
        geo[y].push_back(x);
        geo[x].push_back(y);
    }
    init();
    for(int i=0;i<m;i++){
        scanf("%s%d%d",ch,&x,&y);
        if(ch[0]=='U'){
            work(0,x,y);
        }
        else if(ch[0]=='M'){
            work(1,x,y);
        }
        else{
            work(2,x,y);
        }
    }
    return 0;
}

下面这个是带两点路径查询/修改以及子树查询/修改的代码,查询时查找SUM,同时对M取模(这其实是洛谷上的模板题的板子,破事巨多,数据范围还有问题,而且容易爆栈,真是服了)

#include<cstdlib>
#include<cstdio>
#include<algorithm>
#include<vector>
#define maxn 200005
#define time t
#define LL long long int
using namespace std;
int t=0,root,M,n,m,qx,qy,dep[maxn],size[maxn],fa[maxn],id[maxn],id2x[maxn],son[maxn],top[maxn];
LL qd,sum[maxn*2],/*maxnum[maxn*2],*/add[maxn*2],line[maxn];
vector<int> geo[maxn];
void DFS1(int u){
    dep[u]=dep[fa[u]]+1;
    size[u]=1;
    son[u]=0;
    for(int i=0;i<geo[u].size();i++){
        int op=geo[u][i];
        if(op==fa[u])continue;
        fa[op]=u;
        DFS1(op);
        size[u]+=size[op];
        if(size[son[u]]<size[op]){
            son[u]=op;
        }
    }
}
void DFS2(int x,int tp){
    top[x]=tp;
    id[x]=++time;
    id2x[time]=x;
    if(son[x])DFS2(son[x],tp);
    for(int i=0;i<geo[x].size();i++){
        int op=geo[x][i];
        if(op==fa[x]||op==son[x])continue;
        DFS2(op,op);
    }
}
void build(int l,int r,int o){
    if(l==r){
        sum[o]=line[id2x[l]]/*,maxnum[o]=line[id2x[l]]*/;
        return;
    }
    int mid=((r-l)>>1)+l;
    build(l,mid,o<<1);
    build(mid+1,r,(o<<1)+1);
    sum[o]=(sum[o<<1]+sum[(o<<1)+1])%M;
    //maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
    return;
}
void update(int l,int r,int o){
    if(qx<=l&&r<=qy){
        add[o]=(add[o]+qd)%M;sum[o]=(sum[o]+(((r-l+1)%M)*qd%M))%M;//maxnum[o]+=a;
        return;
    }
    int mid=((r-l)>>1)+l;
    if(qx<=mid)update(l,mid,o<<1);
    if(mid<qy)update(mid+1,r,(o<<1)+1);
    sum[o]=((sum[o<<1]+sum[(o<<1)+1])%M+((r-l+1)%M*add[o])%M)%M;
    //maxnum[o]=max(maxnum[o<<1],maxnum[(o<<1)+1]);
    return;
}
LL get_sum(int l,int r,int o,LL a){
    if(qx<=l&&r<=qy){
        return (sum[o]+a%M*(r-l+1)%M)%M;
    }
    int mid=((r-l)>>1)+l;
    LL ans=0;
    if(qx<=mid){
        ans=(ans+get_sum(l,mid,o<<1,a+add[o]))%M;
    }
    if(qy>mid){
        ans=(ans+get_sum(mid+1,r,(o<<1)+1,a+add[o]))%M;
    }
    return ans;
}
/*int get_max(int l,int r,int o){ if(qx<=l&&r<=qy){ return maxnum[o]; } int mid=((r-l)>>1)+l,ans=0; if(qx<=mid){ ans=max(ans,get_max(l,mid,o<<1)); } if(mid<qy) ans=max(ans,get_max(mid+1,r,(o<<1)+1)); return ans; }*/
void init(){
    DFS1(root);
    DFS2(root,root);
    build(1,n,1);
}
void UPDATE(int x,int y,int z){
    int f1=top[x];
    int f2=top[y];
    while(f1!=f2){
        if(dep[f1]>dep[f2]){
            qx=id[f1],qy=id[x],qd=z;
            update(1,n,1);
            x=fa[f1];
        }
        else{
            qx=id[f2],qy=id[y],qd=z;
            update(1,n,1);
            y=fa[f2];
        }
        f1=top[x];
        f2=top[y];
    }
    if(dep[x]>dep[y])swap(x,y);
    qx=id[x],qy=id[y],qd=z;
    update(1,n,1);
}
/*void MAX(int x,int y){ int f1=top[x]; int f2=top[y]; int ans=0; while(f1!=f2){ if(dep[f1]>dep[f2]){ qx=id[f1],qy=id[x]; ans=max(ans,get_max(1,n,1,0)); x=fa[f1]; } else{ qx=id[f2],qy=id[y]; ans=max(ans,get_max(1,n,1,0)); y=fa[f2]; } f1=top[x]; f2=top[y]; } if(dep[x]>dep[y])swap(x,y); qx=id[x],qy=id[y]; ans=max(ans,get_max(1,n,1,0)); printf("%d\n",ans); }*/
void SUM(int x,int y){
    int f1=top[x];
    int f2=top[y];
    int ans=0;
    while(f1!=f2){
        if(dep[f1]>dep[f2]){
            qx=id[f1],qy=id[x];
            ans=(ans+get_sum(1,n,1,0))%M;
            x=fa[f1];
        }
        else{
            qx=id[f2],qy=id[y];
            ans=(ans+get_sum(1,n,1,0))%M;
            y=fa[f2];
        }
        f1=top[x];
        f2=top[y];
    }
    if(dep[x]>dep[y])swap(x,y);
    qx=id[x],qy=id[y];
    ans=(ans+get_sum(1,n,1,0))%M;
    printf("%lld\n",ans%M);
}
void SUBTREE_UPDATE(int x,int z){
    qx=id[x],qy=id[x]+size[x]-1,qd=z;
    update(1,n,1);
}
void SUBTREE_SUM(int x){
    qx=id[x],qy=id[x]+size[x]-1;
    printf("%lld\n",get_sum(1,n,1,0)%M); 
}
int main(){
    /*freopen("input.txt","r",stdin); freopen("output.txt","w",stdout);*/
    scanf("%d%d%d%d",&n,&m,&root,&M);
    for(int i=1;i<=n;i++)
        scanf("%d",&line[i]);
    LL x,y,z,op;
    for(int i=0;i<n-1;i++){
        scanf("%d%d",&x,&y);
        geo[y].push_back(x);
        geo[x].push_back(y);
    }
    init();
    for(int i=0;i<m;i++){
        scanf("%d",&op);
        if(op==1){
            scanf("%d%d%lld",&x,&y,&z);
            UPDATE(x,y,z);
        }
        else if(op==2){
            scanf("%d%d",&x,&y);
            SUM(x,y);
        }
        else if(op==3){
            scanf("%d%lld",&x,&y);
            SUBTREE_UPDATE(x,y);
        }
        else{
            scanf("%d",&x);
            SUBTREE_SUM(x);
        }
    }
    return 0;
}

然而本人代码还是非常地丑QAQ

细节

  1. 需要注意的就是线段数的那个部分以及查询/修改刚开始的部分,每次判断 top 的深度,然后不断地向上移动

  2. 就是DFS什么的别写错就行啦

总结

树链剖分是个好东西,我们可以用它来搞许多事情