概念
假设有编号从1到n的n个点,每个点都存了一些信息,用[L,R]表示下标从L到R的这些点。线段树的用处就是,对编号连续的一些点进行修改或者统计操作,修改和统计的复杂度都是O(log(n)).线段树的原理,就是,将[1,n]分解成若干特定的子区间(数量不超过4*n),然后,将每个区间[L,R]都分解为少量特定的子区间,通过对这些少量子区间的修改或者统计,来实现快速对[L,R]的修改或者统计。
要点
用线段树统计的东西,必须符合区间加法,否则,不可能通过分成的子区间来得到[L,R]的统计结果。
符合区间加法的例子:
数字之和——总数字之和 = 左区间数字之和 + 右区间数字之和
最大公因数(GCD)——总GCD = gcd( 左区间GCD , 右区间GCD )
最大值——总最大值=max(左区间最大值,右区间最大值)
不符合区间加法的例子:
众数——只知道左右区间的众数,没法求总区间的众数
01序列的最长连续零——只知道左右区间的最长连续零,没法知道总的最长连续零
线段树存储结构
线段树是用数组来模拟树形结构,对于每一个节点R ,左子节点为 2R (一般写作R<<1)右子节点为 2R+1(一般写作R<<1|1). 然后以1为根节点,所以,整体的统计信息是存在节点1中的。线段树需要的数组元素个数是2^log(n)+1,一般都开4倍空间,比如:int A[n<<2];
线段树基本实现(递归方法)
这里以区间求和为例。
// 定义
#define maxn 100007 //元素总个数
int Sum[maxn<<2],Add[maxn<<2];//Sum求和,Add为懒惰标记
int A[maxn],n;//存原数组数据,下标1到n
// 建树
//PushUp函数更新节点信息 ,这里是求和
void PushUp(int rt){Sum[rt]=Sum[rt<<1]+Sum[rt<<1|1];}
//Build函数建树
void Build(int l,int r,int rt){ //l,r表示当前节点区间,rt表示当前节点编号
if(l==r) {//若到达叶节点
Sum[rt]=A[l];//储存数组值
return;
}
int m=(l+r)>>1;
//左右递归
Build(l,m,rt<<1);
Build(m+1,r,rt<<1|1);
//更新信息
PushUp(rt);
}
// 点修改
// 假定 A[L]+=C
void Update(int L,int C,int l,int r,int rt){
//l,r表示当前节点区间,rt表示当前节点编号
if(l==r){//到叶节点,修改
Sum[rt]+=C;
return;
}
int m=(l+r)>>1;
//根据条件判断往左子树调用还是往右
if(L <= m) Update(L,C,l,m,rt<<1);
else Update(L,C,m+1,r,rt<<1|1);
PushUp(rt);//子节点更新了,所以本节点也需要更新信息
}
// 区间修改
// 假定A[L,R]+=C
void Update(int L,int R,int C,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号
if(L <= l && r <= R){//如果本区间完全在操作区间[L,R]以内
Sum[rt]+=C*(r-l+1);//更新数字和,向上保持正确
Add[rt]+=C;//增加Add标记,表示本区间的Sum正确,子区间的Sum仍需要根据Add的值来调整
return ;
}
int m=(l+r)>>1;
PushDown(rt,m-l+1,r-m);//下推标记
//这里判断左右子树跟[L,R]有无交集,有交集才递归
if(L <= m) Update(L,R,C,l,m,rt<<1);
if(R > m) Update(L,R,C,m+1,r,rt<<1|1);
PushUp(rt);//更新本节点信息
}
// 区间查询 询问A[L,R]这个范围的和
// 首先是下推函数
void PushDown(int rt,int ln,int rn){
//ln,rn为左子树,右子树的数字数量
if(Add[rt]){
//下推标记
Add[rt<<1]+=Add[rt];
Add[rt<<1|1]+=Add[rt];
//修改子节点的Sum使之与对应的Add相对应
Sum[rt<<1]+=Add[rt]*ln;
Sum[rt<<1|1]+=Add[rt]*rn;
//清除本节点标记
Add[rt]=0;
}
}
// 然后是查询函数
int Query(int L,int R,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号
if(L <= l && r <= R){
//在区间内,直接返回
return Sum[rt];
}
int m=(l+r)>>1;
//下推标记,否则Sum可能不正确
PushDown(rt,m-l+1,r-m);
//累计答案
int ANS=0;
if(L <= m) ANS+=Query(L,R,l,m,rt<<1);
if(R > m) ANS+=Query(L,R,m+1,r,rt<<1|1);
return ANS;
}
// 整个调用过程
//建树
Build(1,n,1);
//点修改
Update(L,C,1,n,1);
//区间修改
Update(L,R,C,1,n,1);
//区间查询
int ANS=Query(L,R,1,n,1);例题1
#include<bits/stdc++.h>
using namespace std;
#define max_n 100001
// 原始数据
int seg[max_n];
// 最大值数组
int arr[4*max_n];
// p当前节点
// l,r当前节点的区间
void create(int l,int r,int p)
{
if(l==r)
{
arr[p]=seg[l];
return;
}
int m = (l+r)>>1;
create(l,m,p<<1);
create(m+1,r,p<<1|1);
arr[p] = max(arr[p<<1],arr[p<<1|1]);
}
// 修改点
void update(int idx,int val,int l,int r,int p)
{
if(l==r&&l==idx)
{
arr[p]=val;
return;
}
int m = (l+r)>>1;
if(idx<=m) update(idx,val,l,m,p<<1);
if(idx>m) update(idx,val,m+1,r,p<<1|1);
arr[p] = max(arr[p<<1],arr[p<<1|1]);
}
// 区间查询
int query(int L,int R,int l,int r,int p)
{
if(L<=l&&R>=r) return arr[p];
int ans = -1;
int m = (l+r)>>1;
if(L<=m) ans = max(ans,query(L,R,l,m,p<<1));
if(R>m) ans = max(ans,query(L,R,m+1,r,p<<1|1));
return ans;
}
int main()
{
int n,m;
while(cin>>n>>m){
for(int i=1;i<=n;i++){
scanf("%d",&seg[i]);
}
create(1,n,1);
char order;
int a,b;
for(int i=1;i<=m;++i){
cin>>order>>a>>b;
if(order=='U'){
update(a,b,1,n,1);
}else if(order=='Q'){
if(a>b)swap(a,b);
printf("%d\n",query(a,b,1,n,1));
}
}
}
return 0;
}
例题2
也可用dp方法来做。
// 线段树总区间 [1,n]
// 线段树节点结构体
struct asd{
int left,right,mid; // 当前节点的左右区间范围和中点
int val; // 所有数据中落在当前区间的那些值构成的递增序列的最大长度
int cnt; // 组成这个最大长度的方案数
};
vector<asd>q;
class Solution {
public:
// 建树
void build(int i,int left,int right)
{
q[i].left=left;
q[i].right=right;
q[i].mid=(left+right)>>1;
q[i].cnt=q[i].val=0;
if(left==right)
return;
build(i<<1,left,q[i].mid);
build(i<<1|1,q[i].mid+1,right);
}
// 查询落在区间[left,right]内的最大长度和方案数
void query(int i, int left, int right, int &x, int &y )
{
if(q[i].left==left&&q[i].right==right)
{
x=q[i].val;
y=q[i].cnt;
return;
}
if(right<=q[i].mid)
{
query(i<<1, left, right, x, y);
return;
}
if(left>q[i].mid)
{
query(i<<1|1,left,right, x, y);
return;
}
int lx,ly;
int rx,ry;
query(i<<1,left,q[i].mid,lx,ly);
query(i<<1|1,q[i].mid+1,right, rx,ry);
if(lx==rx)
{
x=lx;
y=(ly+ry);
}
else if(lx>rx)
{
x=lx;
y=ly;
}
else
{
x=rx;
y=ry;
}
}
// 用查询的结果更新整棵树
void update(int i,int p,int x,int y)
{
if(q[i].left==q[i].right&&q[i].left==p)
{
if(x>q[i].val)
{
q[i].val=x;
q[i].cnt=y;
}
else if(x==q[i].val)
{
q[i].cnt=(q[i].cnt+y);
}
return;
}
if(p<=q[i].mid)
update(i<<1,p,x,y);
else if(p>q[i].mid)
update(i<<1|1,p,x,y);
if(q[i<<1].val==q[i<<1|1].val)
{
q[i].val=q[i<<1].val;
q[i].cnt=(q[i<<1].cnt+q[i<<1|1].cnt);
}
else if(q[i<<1].val>q[i<<1|1].val)
{
q[i].val=q[i<<1].val;
q[i].cnt=q[i<<1].cnt;
}
else
{
q[i].val=q[i<<1|1].val;
q[i].cnt=q[i<<1|1].cnt;
}
}
int findNumberOfLIS(vector<int>& nums) {
if(nums.empty()) return 0;
int n = nums.size();
// 保存原始数据
vector<int>arr;
arr.resize(n+1);
q.resize(n*4+5);
vector<int>xs;
for(int i=1;i<=n;i++)
{
arr[i] = nums[i-1];
xs.push_back(arr[i]);
}
// 排序 + 去重 + 离散化
sort(xs.begin(),xs.end());
auto e=unique(xs.begin(),xs.end());
for(int i=1;i<=n;i++)
arr[i]=lower_bound(xs.begin(), e, arr[i])-xs.begin()+1;
build(1,0,n);
// 每次查询0~arr[i-1]这个区间范围的最大值x
// 以 x+1 去更新
for(int i=1;i<=n;i++)
{
int x,y;
query(1,0,arr[i]-1,x,y);
//cout<<x<<" "<<y<<endl;
if(!y)
y=1;
update(1, arr[i] , x+1, y);
}
int x,y;
query(1,0,n,x,y);
return y;
}
};例题3
思路:每掉落一个方块,更新该方块落在的区间的最大高度。利用线段树进行区间修改和区间最大值查询。
class Solution {
public:
map<int,int>mp;
const int MX = 2009;
int mx[4 * 3000];
// 查询区间最大值
int query(int L,int R,int l,int r,int rt)
{
if(L<=l&&R>=r) return mx[rt];
int mid = (l+r)>>1;
int ans = 0;
if(mid>=L) ans = max(ans,query(L,R,l,mid,rt<<1));
if(R>mid) ans = max(ans,query(L,R,mid+1,r,rt<<1|1));
return ans;
}
// 区间修改
void update(int L,int R,int l,int r,int rt,int val)
{
if(l==r)
{
mx[rt]=val;
return;
}
int mid = (l+r)>>1;
if(mid>=L) update(L,R,l,mid,rt<<1,val);
if(R>mid) update(L,R,mid+1,r,rt<<1|1,val);
mx[rt] = max(mx[rt<<1],mx[rt<<1|1]);
}
vector<int> fallingSquares(vector<vector<int>>& positions) {
vector<int>ans;
int n = positions.size();
//memset(mx,0,sizeof(mx));
// 离散化
for(int i=0;i<n;++i)
{
mp[positions[i][0]] = 1;
mp[positions[i][0]+positions[i][1]-1] = 1;
}
int cnt = 1;
for(auto& t:mp)
t.second = cnt++;
int l;
int r;
for(int i=0;i<n;++i)
{
l = mp[positions[i][0]];
r = mp[positions[i][0]+positions[i][1]-1];
int res = query(l,r,1,2009,1);
update(l,r,1,2009,1,res+positions[i][1]);
ans.push_back(mx[1]);
}
return ans;
}
};例4 张贴海报
法一
比较容易想到的一个思路。每贴一张海报,将其长度所覆盖的那些位置的值改成这张海报的索引。最后遍历整个数组,统计不重复的海报编号数目即可。时间复杂度较高。
#include<bits/stdc++.h>
using namespace std;
const int max_n = 1e7;
int v[max_n];
int main()
{
memset(v,0,sizeof(v));
int n;
cin>>n;
int a,b;
for(int i=1;i<=n;++i)
{
cin>>a>>b;
for(int j=a;j<=b;++j)
v[j] = i;
}
set<int>se;
for(int i=1;i<max_n;++i)
if(v[i])
se.insert(v[i]);
cout<<se.size()<<endl;
return 0;
}法二 线段树
通过线段树来维护区间。最后统计不重复的海报编号数目。
#include<bits/stdc++.h>
using namespace std;
const int max_n = 1e5;
int seg[4*max_n];
// 懒惰标记,用于延迟传播
int add[4*max_n];
// 建树
void create(int l,int r,int rt)
{
if(l==r)
return;
create(l,(l+r)>>1,rt<<1);
create(((l+r)>>1)+1,r,rt<<1|1);
}
// 下推函数
void pushDown(int rt)
{
if(add[rt])
{
// 将标记下推到左右子区间
add[rt<<1] = add[rt];
add[rt<<1|1] = add[rt];
// 清除本节点的标记
add[rt] = 0;
}
}
void update(int L,int R,int l,int r,int rt,int val)
{
if(L<=l&&R>=r)
{
add[rt] = val; //子区间仍需要根据add值进行调整
return;
}
//if(l==r) // 叶子节点更新
//{
// add[rt] = val;
// return;
//}
// 下推标记
pushDown(rt);
// 继续寻找区间
int mid = (l+r)>>1;
if(L<=mid) update(L,R,l,mid,rt<<1,val);
if(R>mid) update(L,R,mid+1,r,rt<<1|1,val);
}
// 区间查询不重复的海报数目
void query(int l,int r,int rt,set<int>& se)
{
if(add[rt])
{
// if(l==r)
// {
se.insert(add[rt]);
return;
// }
// else pushDown(rt);
}
// 到叶节点了 返回
if(l==r) return;
query(l,(l+r)>>1,rt<<1,se);
query(((l+r)>>1)+1,r,rt<<1|1,se);
}
int main()
{
int n;
cin>>n;
int a,b;
vector<pair<int,int>>data;
// 辅助数组
vector<int>xs;
for(int i=1;i<=n;++i)
{
cin>>a>>b;
data.push_back({a,b});
xs.push_back(a);
xs.push_back(b);
}
// 排序
sort(xs.begin(),xs.end());
// 去重
xs.erase(unique(xs.begin(),xs.end()),xs.end());
create(1,xs.size(),1);
set<int>se;
for(int i=1;i<=n;++i)
{
// 离散化 节约线段树需要开辟的空间
int l = lower_bound(xs.begin(),xs.end(),data[i-1].first)-xs.begin()+1;
int r = lower_bound(xs.begin(),xs.end(),data[i-1].second)-xs.begin()+1;
// 更新
update(l,r,1,xs.size(),1,i);
}
// 统计数目
query(1,xs.size(),1,se);
cout<<se.size()<<endl;
return 0;
}例5
思路
直接统计每个元素右侧比它小的元素个数,O(n^2),超时。此题可有多种解法。
线段树解法:离散化操作后,求得整个序列的最小值min和最大值max,线段树的整个区间为[min,max].
从数组的最后一位开始遍历,更新线段树。若当前元素nums[i],统计nums[i]右侧比它小的元素个数,即统计落在区间[min,nums[i]-1]这个区间的数的个数。故每遇见一个数,就将包含该数的区间的计数值+1,这样在查询时候,统计落在某个区间的数的个数就转化为了区间和问题。
按照此思路,题目可变形为求每个元素左/右侧比它大/小的元素个数。
const int max_n = 1e6;
//int seg[4*max_n];
int sum[4*max_n];
class Solution {
public:
void update(int l,int r,int rt,int val)
{
if(l==r)
{
sum[rt] += 1;
return;
}
int m = (l+r)>>1;
if(val<=m)
update(l,m,rt<<1,val);
else
update(m+1,r,rt<<1|1,val);
sum[rt] = sum[rt<<1]+sum[rt<<1|1];
}
int query(int L,int R,int l,int r, int rt)
{
if(L<=l&&R>=r)
{
return sum[rt];
}
int m = (l+r)>>1;
int ans = 0;
if(L<=m)ans+=query(L,R,l,m,rt<<1);
if(R>m) ans+=query(L,R,m+1,r,rt<<1|1);
return ans;
}
vector<int> countSmaller(vector<int>& nums) {
vector<int>ans;
if(nums.empty()) return ans;
ans.resize(nums.size(),0);
memset(sum,0,sizeof(sum));
int MAX = INT_MIN;
int MIN = INT_MAX;
vector<int>xs;
for(auto t:nums)
xs.push_back(t);
sort(nums.begin(),nums.end());
// 离散化
for(int i=0;i<xs.size();++i)
xs[i] = (lower_bound(nums.begin(),nums.end(),xs[i])-nums.begin()+1);
//for(auto a:xs) cout<<" "<<a;
// 统计整个序列的最值
for(auto t:xs){
MAX = max(MAX,t);
MIN = min(MIN,t);
}
for(int i=nums.size()-1;i>=0;--i)
{
update(MIN,MAX,1,xs[i]);
if(xs[i]==MIN) continue;
int cnt = query(MIN,xs[i]-1,MIN,MAX,1);
ans[i] = cnt;
}
return ans;
}
};
京公网安备 11010502036488号