树链剖分,就是把树上的边划分,然后用数据结构维护的算法
重链剖分:轻边与重边
百度百科的图
一个节点子树中节点数最多的节点叫做它的重儿子,与它相连的边称为重边
多条重边连起来成为重链
我们把重链连续标号,用数据结构维护,即把树上问题转化成了序列上问题,每一个对应了一个节点
记为
的重儿子,
为
的深度,
为
所在重链的顶端,
为
的父亲,
为
的
序(按重儿子优先,即使重链标号连续)
性质
一条根节点到节点路径上重链、轻链数不超过
证明
重链轻链交替
每次走轻链,至少减少一半
所以轻链数
重链数也为
操作
对于修改两点间路径的操作,我们这样做
先跳大的节点,避免跳过头
把到
做修改
再将赋值为
这样做直到、
在同一重链上
对于在同一重链的情况,我们会发现浅的即为
所以再次修改到
查询同理,把数据结构的操作换成查询
时间复杂度
怎样实现
第一遍,处理
第二遍,处理
代码
[ZJOI2008 树的统计]
https://www.luogu.com.cn/problem/P2590
#include <bits/stdc++.h> #define ri register int #define ll long long using namespace std; const int Maxn=3e4; const int Inf=3e4; int lt[Maxn+5],ed[2*Maxn+5],nt[2*Maxn+5],val[Maxn+5],n,q,cnt; int dep[Maxn+5],size[Maxn+5],son[Maxn+5],dfn[Maxn+5],top[Maxn+5],father[Maxn+5],rk[Maxn+5],vt; struct SegTree { #define ls(p) (p<<1) #define rs(p) (p<<1|1) #define mid (((l)+(r))>>1) int v[(Maxn<<2)+5],mx[(Maxn<<2)+5]; void update(int p) { v[p]=v[ls(p)]+v[rs(p)]; mx[p]=max(mx[ls(p)],mx[rs(p)]); } void build(int p,int l,int r) { if(l==r) { v[p]=mx[p]=val[rk[l]];return ; } build(ls(p),l,mid); build(rs(p),mid+1,r); update(p); } void change(int p,int l,int r,int k,int d) { if(l==r) { v[p]=d,mx[p]=d;return ; } if(k<=mid)change(ls(p),l,mid,k,d); else change(rs(p),mid+1,r,k,d); update(p); } int getmax(int p,int l,int r,int L,int R) { if(L<=l&&r<=R)return mx[p]; int ret=-Inf; if(L<=mid)ret=max(ret,getmax(ls(p),l,mid,L,R)); if(R>mid)ret=max(ret,getmax(rs(p),mid+1,r,L,R)); return ret; } int getsum(int p,int l,int r,int L,int R) { if(L<=l&&r<=R)return v[p]; int ret=0; if(L<=mid)ret+=getsum(ls(p),l,mid,L,R); if(R>mid)ret+=getsum(rs(p),mid+1,r,L,R); return ret; } }t; void addedge(int x,int y) { ed[++cnt]=y;nt[cnt]=lt[x];lt[x]=cnt; } void dfs1(int u,int fa) { int maxn=0; dep[u]=dep[fa]+1;size[u]=1;father[u]=fa; for(ri i=lt[u];i;i=nt[i]) { int v=ed[i]; if(v!=fa) { dfs1(v,u); size[u]+=size[v]; if(size[v]>maxn) { maxn=size[v],son[u]=v; } } } } void dfs2(int u,int fa,int now) { dfn[u]=++vt;rk[vt]=u;top[u]=now; if(son[u])dfs2(son[u],u,now); for(ri i=lt[u];i;i=nt[i]) { int v=ed[i]; if(v!=fa&&v!=son[u])dfs2(v,u,v); } } int query_max(int u,int v) { int ans=-Inf; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]])swap(u,v); ans=max(ans,t.getmax(1,1,n,dfn[top[u]],dfn[u])); u=father[top[u]]; } if(dep[u]>dep[v])swap(u,v); ans=max(ans,t.getmax(1,1,n,dfn[u],dfn[v])); return ans; } int query_sum(int u,int v) { int ans=0; while(top[u]!=top[v]) { if(dep[top[u]]<dep[top[v]])swap(u,v); ans+=t.getsum(1,1,n,dfn[top[u]],dfn[u]); u=father[top[u]]; } if(dep[u]>dep[v])swap(u,v); ans+=t.getsum(1,1,n,dfn[u],dfn[v]); return ans; } int main() { scanf("%d",&n); for(ri i=1;i<n;i++) { int x,y; scanf("%d%d",&x,&y); addedge(x,y); addedge(y,x); } for(ri i=1;i<=n;i++)scanf("%d",&val[i]); dfs1(1,0);dfs2(1,0,1);t.build(1,1,n); scanf("%d",&q); while(q--) { char s[6]; scanf("%s",s); int x,y; scanf("%d%d",&x,&y); if(s[0]=='C')t.change(1,1,n,dfn[x],y); else if(s[1]=='M')printf("%d\n",query_max(x,y)); else printf("%d\n",query_sum(x,y)); } return 0; }