概念

假设有编号从1到n的n个点,每个点都存了一些信息,用[L,R]表示下标从L到R的这些点。线段树的用处就是,对编号连续的一些点进行修改或者统计操作,修改和统计的复杂度都是O(log(n)).线段树的原理,就是,将[1,n]分解成若干特定的子区间(数量不超过4*n),然后,将每个区间[L,R]都分解为少量特定的子区间,通过对这些少量子区间的修改或者统计,来实现快速对[L,R]的修改或者统计。

要点

用线段树统计的东西,必须符合区间加法,否则,不可能通过分成的子区间来得到[L,R]的统计结果。
符合区间加法的例子:
数字之和——总数字之和 = 左区间数字之和 + 右区间数字之和
最大公因数(GCD)——总GCD = gcd( 左区间GCD , 右区间GCD )
最大值——总最大值=max(左区间最大值,右区间最大值)

不符合区间加法的例子:
众数——只知道左右区间的众数,没法求总区间的众数
01序列的最长连续零——只知道左右区间的最长连续零,没法知道总的最长连续零

线段树存储结构

线段树是用数组来模拟树形结构,对于每一个节点R ,左子节点为 2R (一般写作R<<1)右子节点为 2R+1(一般写作R<<1|1). 然后以1为根节点,所以,整体的统计信息是存在节点1中的。线段树需要的数组元素个数是2^log(n)+1,一般都开4倍空间,比如:int A[n<<2];

线段树基本实现(递归方法)

这里以区间求和为例。

// 定义
#define maxn 100007  //元素总个数
int Sum[maxn<<2],Add[maxn<<2];//Sum求和,Add为懒惰标记 
int A[maxn],n;//存原数组数据,下标1到n 


// 建树
//PushUp函数更新节点信息 ,这里是求和
void PushUp(int rt){Sum[rt]=Sum[rt<<1]+Sum[rt<<1|1];}

//Build函数建树 
void Build(int l,int r,int rt){ //l,r表示当前节点区间,rt表示当前节点编号
    if(l==r) {//若到达叶节点 
        Sum[rt]=A[l];//储存数组值 
        return;
    }
    int m=(l+r)>>1;
    //左右递归 
    Build(l,m,rt<<1);
    Build(m+1,r,rt<<1|1);
    //更新信息 
    PushUp(rt);
}


// 点修改
// 假定 A[L]+=C
void Update(int L,int C,int l,int r,int rt){
        //l,r表示当前节点区间,rt表示当前节点编号
    if(l==r){//到叶节点,修改 
        Sum[rt]+=C;
        return;
    }
    int m=(l+r)>>1;
    //根据条件判断往左子树调用还是往右 
    if(L <= m) Update(L,C,l,m,rt<<1);
    else       Update(L,C,m+1,r,rt<<1|1);
    PushUp(rt);//子节点更新了,所以本节点也需要更新信息 
} 


// 区间修改
// 假定A[L,R]+=C
void Update(int L,int R,int C,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号 
    if(L <= l && r <= R){//如果本区间完全在操作区间[L,R]以内 
        Sum[rt]+=C*(r-l+1);//更新数字和,向上保持正确
        Add[rt]+=C;//增加Add标记,表示本区间的Sum正确,子区间的Sum仍需要根据Add的值来调整
        return ; 
    }
    int m=(l+r)>>1;
    PushDown(rt,m-l+1,r-m);//下推标记
    //这里判断左右子树跟[L,R]有无交集,有交集才递归 
    if(L <= m) Update(L,R,C,l,m,rt<<1);
    if(R >  m) Update(L,R,C,m+1,r,rt<<1|1); 
    PushUp(rt);//更新本节点信息 
} 


// 区间查询 询问A[L,R]这个范围的和
// 首先是下推函数
void PushDown(int rt,int ln,int rn){
    //ln,rn为左子树,右子树的数字数量
    if(Add[rt]){
        //下推标记 
        Add[rt<<1]+=Add[rt];
        Add[rt<<1|1]+=Add[rt];
        //修改子节点的Sum使之与对应的Add相对应 
        Sum[rt<<1]+=Add[rt]*ln;
        Sum[rt<<1|1]+=Add[rt]*rn;
        //清除本节点标记 
        Add[rt]=0;
    }
}
// 然后是查询函数
int Query(int L,int R,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号
    if(L <= l && r <= R){
        //在区间内,直接返回 
        return Sum[rt];
    }
    int m=(l+r)>>1;
    //下推标记,否则Sum可能不正确
    PushDown(rt,m-l+1,r-m); 

    //累计答案
    int ANS=0;
    if(L <= m) ANS+=Query(L,R,l,m,rt<<1);
    if(R >  m) ANS+=Query(L,R,m+1,r,rt<<1|1);
    return ANS;
} 

// 整个调用过程

    //建树 
    Build(1,n,1); 
    //点修改
    Update(L,C,1,n,1);
    //区间修改 
    Update(L,R,C,1,n,1);
    //区间查询 
    int ANS=Query(L,R,1,n,1);

例题1

图片说明

#include<bits/stdc++.h>
using namespace std;

#define max_n 100001
// 原始数据
int seg[max_n];
// 最大值数组
int arr[4*max_n];
// p当前节点
// l,r当前节点的区间
void create(int l,int r,int p)
{
    if(l==r)
    {
        arr[p]=seg[l];
        return;
    }
    int m = (l+r)>>1;
    create(l,m,p<<1);
    create(m+1,r,p<<1|1);
    arr[p] = max(arr[p<<1],arr[p<<1|1]);
}
// 修改点
void update(int idx,int val,int l,int r,int p)
{
    if(l==r&&l==idx)
    {
        arr[p]=val;
        return;
    }
    int m = (l+r)>>1;
    if(idx<=m) update(idx,val,l,m,p<<1);
    if(idx>m) update(idx,val,m+1,r,p<<1|1);
    arr[p] = max(arr[p<<1],arr[p<<1|1]);
}
// 区间查询
int query(int L,int R,int l,int r,int p)
{
    if(L<=l&&R>=r) return arr[p];
    int ans = -1;
    int m = (l+r)>>1;
    if(L<=m) ans = max(ans,query(L,R,l,m,p<<1));
    if(R>m) ans = max(ans,query(L,R,m+1,r,p<<1|1));
    return ans;
}
int main()
{

int n,m;
    while(cin>>n>>m){
        for(int i=1;i<=n;i++){
            scanf("%d",&seg[i]);
        }
        create(1,n,1);
        char order;
        int a,b;
        for(int i=1;i<=m;++i){
            cin>>order>>a>>b;

            if(order=='U'){

                update(a,b,1,n,1);
            }else if(order=='Q'){

                if(a>b)swap(a,b);
                printf("%d\n",query(a,b,1,n,1));
            }
        }
    }
    return 0;
}

例题2

图片说明
也可用dp方法来做。

// 线段树总区间 [1,n]
// 线段树节点结构体
struct asd{
    int left,right,mid;  // 当前节点的左右区间范围和中点
    int val;  // 所有数据中落在当前区间的那些值构成的递增序列的最大长度
    int cnt;  // 组成这个最大长度的方案数
};
vector<asd>q;

class Solution {
public:

    // 建树
    void build(int i,int left,int right)
{
    q[i].left=left;
    q[i].right=right;
    q[i].mid=(left+right)>>1;
    q[i].cnt=q[i].val=0;
    if(left==right)
        return;
    build(i<<1,left,q[i].mid);
    build(i<<1|1,q[i].mid+1,right);
}
// 查询落在区间[left,right]内的最大长度和方案数
void query(int i, int left, int right, int &x, int &y )
{
    if(q[i].left==left&&q[i].right==right)
    {
        x=q[i].val;
        y=q[i].cnt;
        return;
    }
    if(right<=q[i].mid)
    {
        query(i<<1, left, right, x, y);
        return;
    }
    if(left>q[i].mid)
    {
        query(i<<1|1,left,right, x, y);
        return;
    }
    int lx,ly;
    int rx,ry;
    query(i<<1,left,q[i].mid,lx,ly);
    query(i<<1|1,q[i].mid+1,right, rx,ry);
    if(lx==rx)
    {
        x=lx;
        y=(ly+ry);
    }
    else if(lx>rx)
    {
        x=lx;
        y=ly;
    }
    else
    {
        x=rx;
        y=ry;
    }
}
// 用查询的结果更新整棵树
void update(int i,int p,int x,int y)
{
    if(q[i].left==q[i].right&&q[i].left==p)
    {
        if(x>q[i].val)
        {
            q[i].val=x;
            q[i].cnt=y;
        }
        else if(x==q[i].val)
        {
            q[i].cnt=(q[i].cnt+y);
        }
        return;
    }
    if(p<=q[i].mid)
        update(i<<1,p,x,y);
    else if(p>q[i].mid)
        update(i<<1|1,p,x,y);
    if(q[i<<1].val==q[i<<1|1].val)
    {
        q[i].val=q[i<<1].val;
        q[i].cnt=(q[i<<1].cnt+q[i<<1|1].cnt);
    }
    else if(q[i<<1].val>q[i<<1|1].val)
    {
        q[i].val=q[i<<1].val;
        q[i].cnt=q[i<<1].cnt;
    }
    else
    {
        q[i].val=q[i<<1|1].val;
        q[i].cnt=q[i<<1|1].cnt;
    }
}

    int findNumberOfLIS(vector<int>& nums) {
        if(nums.empty()) return 0;
        int n = nums.size();
        // 保存原始数据
        vector<int>arr;
        arr.resize(n+1);
        q.resize(n*4+5);
        vector<int>xs;
        for(int i=1;i<=n;i++)
    {
        arr[i] = nums[i-1];
        xs.push_back(arr[i]);
    }
        // 排序 + 去重 + 离散化
    sort(xs.begin(),xs.end());
    auto e=unique(xs.begin(),xs.end());
    for(int i=1;i<=n;i++)
        arr[i]=lower_bound(xs.begin(), e, arr[i])-xs.begin()+1;

    build(1,0,n);
    // 每次查询0~arr[i-1]这个区间范围的最大值x
    // 以 x+1 去更新
    for(int i=1;i<=n;i++)
    {
        int x,y;
        query(1,0,arr[i]-1,x,y);
        //cout<<x<<" "<<y<<endl;
         if(!y)
            y=1;
        update(1, arr[i] , x+1, y);
    }
    int x,y;
    query(1,0,n,x,y);
    return y;

    }
};

例题3

图片说明

思路:每掉落一个方块,更新该方块落在的区间的最大高度。利用线段树进行区间修改和区间最大值查询。

class Solution {
public:

    map<int,int>mp;
    const int MX = 2009;
    int mx[4 * 3000];
    // 查询区间最大值
    int query(int L,int R,int l,int r,int rt)
    {
        if(L<=l&&R>=r) return mx[rt];
        int mid = (l+r)>>1;
        int ans = 0;
        if(mid>=L) ans = max(ans,query(L,R,l,mid,rt<<1));
        if(R>mid) ans = max(ans,query(L,R,mid+1,r,rt<<1|1));
        return ans;
    }
    // 区间修改
    void update(int L,int R,int l,int r,int rt,int val)
    {
        if(l==r)
        {
            mx[rt]=val;
            return;
        }
        int mid = (l+r)>>1;
        if(mid>=L) update(L,R,l,mid,rt<<1,val);
        if(R>mid) update(L,R,mid+1,r,rt<<1|1,val);
        mx[rt] = max(mx[rt<<1],mx[rt<<1|1]);
    }
    vector<int> fallingSquares(vector<vector<int>>& positions) {
        vector<int>ans;
        int n = positions.size();
        //memset(mx,0,sizeof(mx));
        // 离散化
        for(int i=0;i<n;++i)
        {
            mp[positions[i][0]] = 1;
            mp[positions[i][0]+positions[i][1]-1] = 1;
        }
        int cnt = 1;
        for(auto& t:mp)
            t.second = cnt++;
        int l;
        int r;
        for(int i=0;i<n;++i)
        {
            l = mp[positions[i][0]];
            r = mp[positions[i][0]+positions[i][1]-1];
            int res = query(l,r,1,2009,1);
            update(l,r,1,2009,1,res+positions[i][1]);
            ans.push_back(mx[1]);
        }
        return ans;
    }
};

例4 张贴海报

图片说明

法一

比较容易想到的一个思路。每贴一张海报,将其长度所覆盖的那些位置的值改成这张海报的索引。最后遍历整个数组,统计不重复的海报编号数目即可。时间复杂度较高。

#include<bits/stdc++.h>
using namespace std;
const int max_n = 1e7;
int v[max_n];
int main()
{
    memset(v,0,sizeof(v));
    int n;
    cin>>n;
    int a,b;
        for(int i=1;i<=n;++i)
        {
            cin>>a>>b;
            for(int j=a;j<=b;++j)
                v[j] = i;
       }

    set<int>se;
    for(int i=1;i<max_n;++i)
        if(v[i])
            se.insert(v[i]);
    cout<<se.size()<<endl;
    return 0;
}

法二 线段树

通过线段树来维护区间。最后统计不重复的海报编号数目。

#include<bits/stdc++.h>
using namespace std;
const int max_n = 1e5;
int seg[4*max_n];
// 懒惰标记,用于延迟传播
int add[4*max_n];
// 建树
void create(int l,int r,int rt)
{
    if(l==r)
        return;
    create(l,(l+r)>>1,rt<<1);
    create(((l+r)>>1)+1,r,rt<<1|1);
}
// 下推函数
void pushDown(int rt)
{
    if(add[rt])
    {
        // 将标记下推到左右子区间
        add[rt<<1] = add[rt]; 
        add[rt<<1|1] = add[rt];
        // 清除本节点的标记
        add[rt] = 0;
    }
}
void update(int L,int R,int l,int r,int rt,int val)
{
    if(L<=l&&R>=r)
    {
        add[rt] = val; //子区间仍需要根据add值进行调整
        return;
    }
    //if(l==r)  //  叶子节点更新
    //{
    //    add[rt] = val;
    //    return;
    //}
    // 下推标记
    pushDown(rt);
    // 继续寻找区间
    int mid = (l+r)>>1;
    if(L<=mid) update(L,R,l,mid,rt<<1,val);
    if(R>mid) update(L,R,mid+1,r,rt<<1|1,val);

}
// 区间查询不重复的海报数目
void query(int l,int r,int rt,set<int>& se)
{
    if(add[rt])
    {
       // if(l==r)
       // {
            se.insert(add[rt]);
            return;
      //  }
       // else pushDown(rt);
    }
    // 到叶节点了 返回
    if(l==r) return;
    query(l,(l+r)>>1,rt<<1,se);
    query(((l+r)>>1)+1,r,rt<<1|1,se);
}
int main()
{
    int n;
    cin>>n;
    int a,b;
    vector<pair<int,int>>data;
    // 辅助数组
    vector<int>xs;
    for(int i=1;i<=n;++i)
    {
       cin>>a>>b;
       data.push_back({a,b});
       xs.push_back(a);
       xs.push_back(b);
    }
    // 排序
    sort(xs.begin(),xs.end());
    // 去重
    xs.erase(unique(xs.begin(),xs.end()),xs.end());
    create(1,xs.size(),1);
    set<int>se;
    for(int i=1;i<=n;++i)
    {
        // 离散化 节约线段树需要开辟的空间
       int l = lower_bound(xs.begin(),xs.end(),data[i-1].first)-xs.begin()+1;
       int r = lower_bound(xs.begin(),xs.end(),data[i-1].second)-xs.begin()+1;
        // 更新
       update(l,r,1,xs.size(),1,i);
    }
    // 统计数目
    query(1,xs.size(),1,se);
    cout<<se.size()<<endl;
    return 0;
}

例5

图片说明

思路

直接统计每个元素右侧比它小的元素个数,O(n^2),超时。此题可有多种解法。
线段树解法:离散化操作后,求得整个序列的最小值min和最大值max,线段树的整个区间为[min,max].
从数组的最后一位开始遍历,更新线段树。若当前元素nums[i],统计nums[i]右侧比它小的元素个数,即统计落在区间[min,nums[i]-1]这个区间的数的个数。故每遇见一个数,就将包含该数的区间的计数值+1,这样在查询时候,统计落在某个区间的数的个数就转化为了区间和问题。

按照此思路,题目可变形为求每个元素左/右侧比它大/小的元素个数。

const int max_n = 1e6;
//int seg[4*max_n];
int sum[4*max_n];
class Solution {
public:

    void update(int l,int r,int rt,int val)
    {
        if(l==r)
        {
            sum[rt] += 1;
            return;
        }
        int m = (l+r)>>1;
        if(val<=m)
            update(l,m,rt<<1,val);
        else 
            update(m+1,r,rt<<1|1,val);
        sum[rt] = sum[rt<<1]+sum[rt<<1|1];
    }
    int query(int L,int R,int l,int r, int rt)
    {
        if(L<=l&&R>=r)
        {
            return sum[rt];
        }
        int m = (l+r)>>1;
        int ans = 0;
        if(L<=m)ans+=query(L,R,l,m,rt<<1);
        if(R>m) ans+=query(L,R,m+1,r,rt<<1|1);
        return ans;
    }
    vector<int> countSmaller(vector<int>& nums) {
        vector<int>ans;
        if(nums.empty()) return ans;
        ans.resize(nums.size(),0);
        memset(sum,0,sizeof(sum));
        int MAX = INT_MIN;
        int MIN = INT_MAX;

        vector<int>xs;
        for(auto t:nums)
            xs.push_back(t);
        sort(nums.begin(),nums.end());
        // 离散化
        for(int i=0;i<xs.size();++i)
            xs[i] = (lower_bound(nums.begin(),nums.end(),xs[i])-nums.begin()+1);
        //for(auto a:xs) cout<<" "<<a;
        // 统计整个序列的最值
        for(auto t:xs){
            MAX = max(MAX,t);
            MIN = min(MIN,t);
        }
        for(int i=nums.size()-1;i>=0;--i)
        {
            update(MIN,MAX,1,xs[i]);
            if(xs[i]==MIN) continue;
            int cnt = query(MIN,xs[i]-1,MIN,MAX,1);
            ans[i] = cnt;
        }

        return ans;
    }
};