树链剖分,就是把树上的边划分,然后用数据结构维护的算法

重链剖分:轻边与重边

图片说明
百度百科的图
一个节点子树中节点数最多的节点叫做它的重儿子,与它相连的边称为重边
多条重边连起来成为重链
我们把重链连续标号,用数据结构维护,即把树上问题转化成了序列上问题,每一个对应了一个节点
的重儿子,的深度,所在重链的顶端,的父亲,序(按重儿子优先,即使重链标号连续)

性质

一条根节点到节点路径上重链、轻链数不超过

证明

重链轻链交替
每次走轻链,至少减少一半
所以轻链数
重链数也为

操作

对于修改两点间路径的操作,我们这样做
先跳大的节点,避免跳过头
做修改
再将赋值为
这样做直到在同一重链上
对于在同一重链的情况,我们会发现浅的即为
所以再次修改
查询同理,把数据结构的操作换成查询

时间复杂度

怎样实现

第一遍,处理
第二遍,处理

代码

[ZJOI2008 树的统计]
https://www.luogu.com.cn/problem/P2590

#include <bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int Maxn=3e4;
const int Inf=3e4;
int lt[Maxn+5],ed[2*Maxn+5],nt[2*Maxn+5],val[Maxn+5],n,q,cnt;
int dep[Maxn+5],size[Maxn+5],son[Maxn+5],dfn[Maxn+5],top[Maxn+5],father[Maxn+5],rk[Maxn+5],vt;
struct SegTree {
    #define ls(p) (p<<1)
    #define rs(p) (p<<1|1)
    #define mid (((l)+(r))>>1)
    int v[(Maxn<<2)+5],mx[(Maxn<<2)+5];
    void update(int p) {
        v[p]=v[ls(p)]+v[rs(p)];
        mx[p]=max(mx[ls(p)],mx[rs(p)]);
    }
    void build(int p,int l,int r) {
        if(l==r) {
            v[p]=mx[p]=val[rk[l]];return ;
        }
        build(ls(p),l,mid);
        build(rs(p),mid+1,r);
        update(p);
    }
    void change(int p,int l,int r,int k,int d) {
        if(l==r) {
            v[p]=d,mx[p]=d;return ;
        }
        if(k<=mid)change(ls(p),l,mid,k,d);
        else change(rs(p),mid+1,r,k,d);
        update(p);
    }
    int getmax(int p,int l,int r,int L,int R) {
        if(L<=l&&r<=R)return mx[p];
        int ret=-Inf;
        if(L<=mid)ret=max(ret,getmax(ls(p),l,mid,L,R));
        if(R>mid)ret=max(ret,getmax(rs(p),mid+1,r,L,R));
        return ret;
    }
    int getsum(int p,int l,int r,int L,int R) {
        if(L<=l&&r<=R)return v[p];
        int ret=0;
        if(L<=mid)ret+=getsum(ls(p),l,mid,L,R);
        if(R>mid)ret+=getsum(rs(p),mid+1,r,L,R);
        return ret;
    }
}t;
void addedge(int x,int y) {
    ed[++cnt]=y;nt[cnt]=lt[x];lt[x]=cnt;
}
void dfs1(int u,int fa) {
    int maxn=0;
    dep[u]=dep[fa]+1;size[u]=1;father[u]=fa;
    for(ri i=lt[u];i;i=nt[i]) {
        int v=ed[i];
        if(v!=fa) {
            dfs1(v,u);
            size[u]+=size[v];
            if(size[v]>maxn) {
                maxn=size[v],son[u]=v;
            } 
        }
    }
}
void dfs2(int u,int fa,int now) {
    dfn[u]=++vt;rk[vt]=u;top[u]=now;
    if(son[u])dfs2(son[u],u,now);
    for(ri i=lt[u];i;i=nt[i]) {
        int v=ed[i];
        if(v!=fa&&v!=son[u])dfs2(v,u,v);
    }
}
int query_max(int u,int v) {
    int ans=-Inf;
    while(top[u]!=top[v]) {
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        ans=max(ans,t.getmax(1,1,n,dfn[top[u]],dfn[u]));
        u=father[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    ans=max(ans,t.getmax(1,1,n,dfn[u],dfn[v]));
    return ans;
}
int query_sum(int u,int v) {
    int ans=0;
    while(top[u]!=top[v]) {
        if(dep[top[u]]<dep[top[v]])swap(u,v);
        ans+=t.getsum(1,1,n,dfn[top[u]],dfn[u]);
        u=father[top[u]];
    }
    if(dep[u]>dep[v])swap(u,v);
    ans+=t.getsum(1,1,n,dfn[u],dfn[v]);
    return ans;
}
int main() {
    scanf("%d",&n);
    for(ri i=1;i<n;i++) {
        int x,y;
        scanf("%d%d",&x,&y);
        addedge(x,y);
        addedge(y,x);
    }
    for(ri i=1;i<=n;i++)scanf("%d",&val[i]);
    dfs1(1,0);dfs2(1,0,1);t.build(1,1,n);
    scanf("%d",&q);
    while(q--) {
        char s[6];
        scanf("%s",s);
        int x,y;
        scanf("%d%d",&x,&y);
        if(s[0]=='C')t.change(1,1,n,dfn[x],y);
        else if(s[1]=='M')printf("%d\n",query_max(x,y));
        else printf("%d\n",query_sum(x,y));
    }
    return 0;
}