线段树学习笔记

前言

这篇博客是我这篇博客的延续作品(窝是先学习树状数组的)。在这篇博客中,将会十分浅、浅的不得了、不能再浅地浅谈一下线段树。好了,废话不多说了,我们开始吧

  • 线段树引入

我们先把树状数组的那张图拿来:

404

然后我们把它东拉西扯一下,就变成了这个东西:

404

(注意:线段树的所有节点都是实的,所以开数组的时候要注意扩大一些)

线段树的作用和树状数组几乎一样,用树状数组做的题目基本上都可以用线段树来做,而用线段树做的东西却不一定能用树状数组来做。所以学一下线段树是很有必要的。

  • 线段树的性质

线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。(摘自百度百科)

线段树线段树的每一个结点储存的是它左右儿子节点的信息(和,最大值...都可以)。

404

  • 建树

接下来,我们要构建那么一棵线段树。建树过程我们选择从上往下进行

考虑:知道一个结点,要怎么计算出它的左儿子和右儿子呢??

其实仔细观察图形就可以发现:

先放代码:

int sumv[N<<2];//这里表示线段树的结点
void pushup(int o){sumv[o]=sumv[o<<1]+sumv[o<<1|1];}//此函数是把信息上放
//o<<1 -> o*2
//o<<1|1 -> o*n+1
void build(int o,int l,int r){//o表示当前结点,l与r表示当前节点所控制的左右区间
    if(l==r){sumv[o]=a[l];return;}//如果l==r(就是指最底层时),那么直接赋值
    int mid=(l+r)>>1;//取中间节点,为了后面找左右儿子
    build(o<<1,l,mid);//左儿子
    build(o<<1|1,mid+1,r);//右儿子
    pushup(o);//我们每搜完一个结点的左右儿子,就要上放一下信息
}
调用函数:build(1,1,n);

函数表示:我们已经搜完左右儿子,那么我们就把当前结点的值表示成左儿子和右儿子的和(其实线段树不一定存和,但为了好解释,这里我们就用和来解释了。)(换句话说,只要你修改了它的左或右儿子,就要进行pushup)

这样我们就完成了建树的过程。

单点修改,区间查询

  • 单点修改

那么我们开始讨论单点修改

先放代码:

int sumv[N<<2];
void pushup(int o){sumv[o]=sumv[o<<1]+sumv[o<<1|1];}
void change(int o,int l,int r,int q,int v){//当前结点o,左右控制的区间l,r,在q的位置加上v
    if(l==r){sumv[o]+=v;return;}//如果是最后一排,直接改变
    int mid=(l+r)>>1;//取中间点(反正只要是线段树,肯定是要取中间点的)
    if(q<=mid)change(o<<1,l,mid,q,v);如果左儿子包含这个要改变的结点,那就搜左儿子
    else change(o<<1|1,mid+1,r,q,v);//否则就搜右儿子
    pushup(o); //上放信息别忘!
}
调用函数:change(1,1,n,q,v);

这样我们就完成了单点修改的过程。

  • 区间查询

那么我们要求一个区间的和呢?

先放代码:

int querysum(int o,int l,int r,int ql,int qr){//当前结点o,左右控制的区间l,r,查询的区间ql,qr

/*
求区间和要分类讨论:
1、完全在[ql,qr]中 -> 此时直接返回结点的值
2、它的在[ql,qr]中有左儿子包含的区间 -> 此时遍历左儿子
3、它的在[ql,qr]中有右儿子包含的区间 -> 此时遍历右儿子
注意:这里不能用else if
因为三种情况是独立的,可能会满足多种情况
*/
    if(ql<=l&&r<=qr)return sumv[o];
    int ans=0;
    int mid=(l+r)>>1;
    if(ql<=mid)ans+=querysum(o<<1,l,mid,ql,qr);
    if(qr>mid)ans+=querysum(o<<1|1,mid+1,r,ql,qr);
    return ans;
    //没有进行修改,不用pushup
}

这样我们就完成了区间查询

再来放一份完整代码

#include<bits/stdc++.h>
using namespace std;
const int N=1e6+10;
int n,a[N],m;
int sumv[N<<2];
void pushup(int o){sumv[o]=sumv[o<<1]+sumv[o<<1|1];}
void build(int o,int l,int r){
    if(l==r){sumv[o]=a[l];return;}
    int mid=(l+r)>>1;
    build(o<<1,l,mid);
    build(o<<1|1,mid+1,r);
    pushup(o);
}
void change(int o,int l,int r,int q,int v){
    if(l==r){sumv[o]+=v;return;}
    int mid=(l+r)>>1;
    if(q<=mid)change(o<<1,l,mid,q,v);
    else change(o<<1|1,mid+1,r,q,v);
    pushup(o); 
}
int querysum(int o,int l,int r,int ql,int qr){
    if(ql<=l&&r<=qr)return sumv[o];
    int ans=0;
    int mid=(l+r)>>1;
    if(ql<=mid)ans+=querysum(o<<1,l,mid,ql,qr);
    if(qr>mid)ans+=querysum(o<<1|1,mid+1,r,ql,qr);
    return ans;
}
int main(){
    scanf("%d%d",&n,&m);
    for(int i=1;i<=n;i++)scanf("%d",a+i);
    build(1,1,n);
    while(m--){
        int opt;scanf("%d",&opt);
        if(opt==1){int x,y;scanf("%d%d",&x,&y);change(1,1,n,x,y);}
        if(opt==2){int l,r;scanf("%d%d",&l,&r);printf("%d\n",querysum(1,1,n,l,r));}
    }
}

看上去结束了??

没有!

稍微难一点?

如果要把区间内所有的数都加上一个,那么怎么做呢?

for循环单点修改?

直接暴力地做单点加显然不可行

那么我们考虑到一个问题,就是每一个点的值我们只有用到它了,对他做的修改才会有意义

因此我们可以先记录下来一个点的被加上的值,相当于为它打上一个标记,等待稍后处理

然后等到用的时候,处理好这个标记即可

这种思想叫做延迟处理(也称),在线段树上的应用是它最出名的应用

区间修改,区间查询

(这个就难多了,自己也不是很理解,可能写的有点模糊(甚至写错))

那么我们的建树过程还是不会变的。会变的是修改操作

这里我们引入一个新数组用来保存每个结点所修改的值。

我们要记住一个很重要的点:如果要修改一个点(或区间),肯定时同步修改的,不可能只修改其中的一个

  • 函数

    inline void puttag(int o,ll v,int l,int r){addv[o]+=v;sumv[o]+=v*(r-l+1);}
    //当前结点o,左右控制的区间l,r,当前节点所表示的区间需要修改v。
    区间长度r-l+1不解释
  • 函数

字面意思,把一个结点的信息下放。不过为什么我们又要下放了呢?(这不是退化了吗)

因为区间修改遍历一个点的左右儿子时,需要用到这些信息(因为修改是从上往下的)

为什么单点修改不用下放?细节留给读者思考

void pushdown(int o,int l,int r){
    if(addv[o]==0)return;//如果根本没有修改过,直接return
    addv[o<<1]+=addv[o];//往左儿子下放修改的值
    addv[o<<1|1]+=addv[o];//往右儿子下放修改的值
    int mid=(l+r)>>1;
    sumv[o<<1]+=addv[o]*(mid-l+1);//下放区间值
    sumv[o<<1|1]+=addv[o]*(r-mid);//下放区间值
    addv[o]=0;//清空
}
  • 区间修改

有了以上的处理,就简单很多了

void optadd(int o,int l,int r,int ql,int qr,int v){
//和单点修改性质差不多,不做解释
    if(ql<=l&&r<=qr){puttag(o,l,r,v);return;}
    int mid=(l+r)>>1;
    pushdown(o,l,r);
    if(ql<=mid)optadd(o<<1,l,mid,ql,qr,v);
    if(qr>mid)optadd(o<<1|1,mid+1,r,ql,qr,v);
    pushup(o);
}
  • 区间查询

和原来基本一样

inline ll querysum(int o,int l,int r,int ql,int qr)
{
    if(ql<=l&&r<=qr)return sumv[o];
    pushdown(o,l,r);//要遍历左儿子和右儿子啦,下放
    ll ans=0;
    int mid=(l+r)>>1;
    if(ql<=mid)ans+=querysum(o<<1,l,mid,ql,qr);
    if(qr>mid)ans+=querysum(o<<1|1,mid+1,r,ql,qr);
    return ans;
}

完整代码

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int MAXN=1e5+10;
ll sumv[MAXN<<2];
ll addv[MAXN<<2];
int n,m;
ll a[MAXN];
inline ll read()
{
    ll tot=0;int f=1;char c=getchar();
    while(c<'0'||c>'9'){if(c=='-')f=-1;c=getchar();}
    while(c>='0'&&c<='9'){tot=(tot<<1)+(tot<<3)+c-'0';c=getchar();}
    return tot*f;
}
inline void pushup(int o){sumv[o]=sumv[o<<1]+sumv[o<<1|1];}
inline void pushdown(int o,int l,int r)
{
    if(addv[o]==0)return;
    addv[o<<1]+=addv[o];
    addv[o<<1|1]+=addv[o];
    int mid=(l+r)>>1;
    sumv[o<<1]+=addv[o]*(mid-l+1);
    sumv[o<<1|1]+=addv[o]*(r-mid);
    addv[o]=0;
}
inline void puttag(int o,ll v,int l,int r){addv[o]+=v;sumv[o]+=v*(r-l+1);}
inline void build(int o,int l,int r)
{
    addv[o]=0;
    if(l==r){sumv[o]=a[l];return;}
    int mid=(l+r)>>1;
    build(o<<1,l,mid);
    build(o<<1|1,mid+1,r);
    pushup(o);
}
inline void optadd(int o,int l,int r,int ql,int qr,int v)
{
    if(ql<=l&&r<=qr){puttag(o,v,l,r);return;}
    pushdown(o,l,r);
    int mid=(l+r)>>1;
    if(ql<=mid)optadd(o<<1,l,mid,ql,qr,v);
    if(qr>mid)optadd(o<<1|1,mid+1,r,ql,qr,v);
    pushup(o);
}
inline ll querysum(int o,int l,int r,int ql,int qr)
{
    if(ql<=l&&r<=qr)return sumv[o];
    pushdown(o,l,r);
    ll ans=0;
    int mid=(l+r)>>1;
    if(ql<=mid)ans+=querysum(o<<1,l,mid,ql,qr);
    if(qr>mid)ans+=querysum(o<<1|1,mid+1,r,ql,qr);
    return ans;
}
int main()
{
    n=read();m=read();
    //cout<<n<<" "<<m<<endl;
    for(int i=1;i<=n;i++)a[i]=read();
    build(1,1,n);
    int p,l,r,v;
    for(int i=1;i<=m;i++)
    {
        p=read();
        if(p==1)l=read(),r=read(),v=read(),optadd(1,1,n,l,r,v);
        else l=read(),r=read(),printf("%lld\n",querysum(1,1,n,l,r));
    }
    return 0;
}

后记

此篇博客是我第一次尝试理解线段树,写的不好请见谅。补充说明:对于单点查询,实际就是区间查询的一个子问题。就是长度为1的区间嘛。

听说线段树还可以区间乘上一个数,但由于作者太菜,还是没能搞懂。(主席树和树链剖分都要听睡着了呢)