概念
假设有编号从1到n的n个点,每个点都存了一些信息,用[L,R]表示下标从L到R的这些点。线段树的用处就是,对编号连续的一些点进行修改或者统计操作,修改和统计的复杂度都是O(log(n)).线段树的原理,就是,将[1,n]分解成若干特定的子区间(数量不超过4*n),然后,将每个区间[L,R]都分解为少量特定的子区间,通过对这些少量子区间的修改或者统计,来实现快速对[L,R]的修改或者统计。
要点
用线段树统计的东西,必须符合区间加法,否则,不可能通过分成的子区间来得到[L,R]的统计结果。
符合区间加法的例子:
数字之和——总数字之和 = 左区间数字之和 + 右区间数字之和
最大公因数(GCD)——总GCD = gcd( 左区间GCD , 右区间GCD )
最大值——总最大值=max(左区间最大值,右区间最大值)
不符合区间加法的例子:
众数——只知道左右区间的众数,没法求总区间的众数
01序列的最长连续零——只知道左右区间的最长连续零,没法知道总的最长连续零
线段树存储结构
线段树是用数组来模拟树形结构,对于每一个节点R ,左子节点为 2R (一般写作R<<1)右子节点为 2R+1(一般写作R<<1|1). 然后以1为根节点,所以,整体的统计信息是存在节点1中的。线段树需要的数组元素个数是2^log(n)+1,一般都开4倍空间,比如:int A[n<<2];
线段树基本实现(递归方法)
这里以区间求和为例。
// 定义 #define maxn 100007 //元素总个数 int Sum[maxn<<2],Add[maxn<<2];//Sum求和,Add为懒惰标记 int A[maxn],n;//存原数组数据,下标1到n // 建树 //PushUp函数更新节点信息 ,这里是求和 void PushUp(int rt){Sum[rt]=Sum[rt<<1]+Sum[rt<<1|1];} //Build函数建树 void Build(int l,int r,int rt){ //l,r表示当前节点区间,rt表示当前节点编号 if(l==r) {//若到达叶节点 Sum[rt]=A[l];//储存数组值 return; } int m=(l+r)>>1; //左右递归 Build(l,m,rt<<1); Build(m+1,r,rt<<1|1); //更新信息 PushUp(rt); } // 点修改 // 假定 A[L]+=C void Update(int L,int C,int l,int r,int rt){ //l,r表示当前节点区间,rt表示当前节点编号 if(l==r){//到叶节点,修改 Sum[rt]+=C; return; } int m=(l+r)>>1; //根据条件判断往左子树调用还是往右 if(L <= m) Update(L,C,l,m,rt<<1); else Update(L,C,m+1,r,rt<<1|1); PushUp(rt);//子节点更新了,所以本节点也需要更新信息 } // 区间修改 // 假定A[L,R]+=C void Update(int L,int R,int C,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号 if(L <= l && r <= R){//如果本区间完全在操作区间[L,R]以内 Sum[rt]+=C*(r-l+1);//更新数字和,向上保持正确 Add[rt]+=C;//增加Add标记,表示本区间的Sum正确,子区间的Sum仍需要根据Add的值来调整 return ; } int m=(l+r)>>1; PushDown(rt,m-l+1,r-m);//下推标记 //这里判断左右子树跟[L,R]有无交集,有交集才递归 if(L <= m) Update(L,R,C,l,m,rt<<1); if(R > m) Update(L,R,C,m+1,r,rt<<1|1); PushUp(rt);//更新本节点信息 } // 区间查询 询问A[L,R]这个范围的和 // 首先是下推函数 void PushDown(int rt,int ln,int rn){ //ln,rn为左子树,右子树的数字数量 if(Add[rt]){ //下推标记 Add[rt<<1]+=Add[rt]; Add[rt<<1|1]+=Add[rt]; //修改子节点的Sum使之与对应的Add相对应 Sum[rt<<1]+=Add[rt]*ln; Sum[rt<<1|1]+=Add[rt]*rn; //清除本节点标记 Add[rt]=0; } } // 然后是查询函数 int Query(int L,int R,int l,int r,int rt){//L,R表示操作区间,l,r表示当前节点区间,rt表示当前节点编号 if(L <= l && r <= R){ //在区间内,直接返回 return Sum[rt]; } int m=(l+r)>>1; //下推标记,否则Sum可能不正确 PushDown(rt,m-l+1,r-m); //累计答案 int ANS=0; if(L <= m) ANS+=Query(L,R,l,m,rt<<1); if(R > m) ANS+=Query(L,R,m+1,r,rt<<1|1); return ANS; } // 整个调用过程 //建树 Build(1,n,1); //点修改 Update(L,C,1,n,1); //区间修改 Update(L,R,C,1,n,1); //区间查询 int ANS=Query(L,R,1,n,1);
例题1
#include<bits/stdc++.h> using namespace std; #define max_n 100001 // 原始数据 int seg[max_n]; // 最大值数组 int arr[4*max_n]; // p当前节点 // l,r当前节点的区间 void create(int l,int r,int p) { if(l==r) { arr[p]=seg[l]; return; } int m = (l+r)>>1; create(l,m,p<<1); create(m+1,r,p<<1|1); arr[p] = max(arr[p<<1],arr[p<<1|1]); } // 修改点 void update(int idx,int val,int l,int r,int p) { if(l==r&&l==idx) { arr[p]=val; return; } int m = (l+r)>>1; if(idx<=m) update(idx,val,l,m,p<<1); if(idx>m) update(idx,val,m+1,r,p<<1|1); arr[p] = max(arr[p<<1],arr[p<<1|1]); } // 区间查询 int query(int L,int R,int l,int r,int p) { if(L<=l&&R>=r) return arr[p]; int ans = -1; int m = (l+r)>>1; if(L<=m) ans = max(ans,query(L,R,l,m,p<<1)); if(R>m) ans = max(ans,query(L,R,m+1,r,p<<1|1)); return ans; } int main() { int n,m; while(cin>>n>>m){ for(int i=1;i<=n;i++){ scanf("%d",&seg[i]); } create(1,n,1); char order; int a,b; for(int i=1;i<=m;++i){ cin>>order>>a>>b; if(order=='U'){ update(a,b,1,n,1); }else if(order=='Q'){ if(a>b)swap(a,b); printf("%d\n",query(a,b,1,n,1)); } } } return 0; }
例题2
也可用dp方法来做。
// 线段树总区间 [1,n] // 线段树节点结构体 struct asd{ int left,right,mid; // 当前节点的左右区间范围和中点 int val; // 所有数据中落在当前区间的那些值构成的递增序列的最大长度 int cnt; // 组成这个最大长度的方案数 }; vector<asd>q; class Solution { public: // 建树 void build(int i,int left,int right) { q[i].left=left; q[i].right=right; q[i].mid=(left+right)>>1; q[i].cnt=q[i].val=0; if(left==right) return; build(i<<1,left,q[i].mid); build(i<<1|1,q[i].mid+1,right); } // 查询落在区间[left,right]内的最大长度和方案数 void query(int i, int left, int right, int &x, int &y ) { if(q[i].left==left&&q[i].right==right) { x=q[i].val; y=q[i].cnt; return; } if(right<=q[i].mid) { query(i<<1, left, right, x, y); return; } if(left>q[i].mid) { query(i<<1|1,left,right, x, y); return; } int lx,ly; int rx,ry; query(i<<1,left,q[i].mid,lx,ly); query(i<<1|1,q[i].mid+1,right, rx,ry); if(lx==rx) { x=lx; y=(ly+ry); } else if(lx>rx) { x=lx; y=ly; } else { x=rx; y=ry; } } // 用查询的结果更新整棵树 void update(int i,int p,int x,int y) { if(q[i].left==q[i].right&&q[i].left==p) { if(x>q[i].val) { q[i].val=x; q[i].cnt=y; } else if(x==q[i].val) { q[i].cnt=(q[i].cnt+y); } return; } if(p<=q[i].mid) update(i<<1,p,x,y); else if(p>q[i].mid) update(i<<1|1,p,x,y); if(q[i<<1].val==q[i<<1|1].val) { q[i].val=q[i<<1].val; q[i].cnt=(q[i<<1].cnt+q[i<<1|1].cnt); } else if(q[i<<1].val>q[i<<1|1].val) { q[i].val=q[i<<1].val; q[i].cnt=q[i<<1].cnt; } else { q[i].val=q[i<<1|1].val; q[i].cnt=q[i<<1|1].cnt; } } int findNumberOfLIS(vector<int>& nums) { if(nums.empty()) return 0; int n = nums.size(); // 保存原始数据 vector<int>arr; arr.resize(n+1); q.resize(n*4+5); vector<int>xs; for(int i=1;i<=n;i++) { arr[i] = nums[i-1]; xs.push_back(arr[i]); } // 排序 + 去重 + 离散化 sort(xs.begin(),xs.end()); auto e=unique(xs.begin(),xs.end()); for(int i=1;i<=n;i++) arr[i]=lower_bound(xs.begin(), e, arr[i])-xs.begin()+1; build(1,0,n); // 每次查询0~arr[i-1]这个区间范围的最大值x // 以 x+1 去更新 for(int i=1;i<=n;i++) { int x,y; query(1,0,arr[i]-1,x,y); //cout<<x<<" "<<y<<endl; if(!y) y=1; update(1, arr[i] , x+1, y); } int x,y; query(1,0,n,x,y); return y; } };
例题3
思路:每掉落一个方块,更新该方块落在的区间的最大高度。利用线段树进行区间修改和区间最大值查询。
class Solution { public: map<int,int>mp; const int MX = 2009; int mx[4 * 3000]; // 查询区间最大值 int query(int L,int R,int l,int r,int rt) { if(L<=l&&R>=r) return mx[rt]; int mid = (l+r)>>1; int ans = 0; if(mid>=L) ans = max(ans,query(L,R,l,mid,rt<<1)); if(R>mid) ans = max(ans,query(L,R,mid+1,r,rt<<1|1)); return ans; } // 区间修改 void update(int L,int R,int l,int r,int rt,int val) { if(l==r) { mx[rt]=val; return; } int mid = (l+r)>>1; if(mid>=L) update(L,R,l,mid,rt<<1,val); if(R>mid) update(L,R,mid+1,r,rt<<1|1,val); mx[rt] = max(mx[rt<<1],mx[rt<<1|1]); } vector<int> fallingSquares(vector<vector<int>>& positions) { vector<int>ans; int n = positions.size(); //memset(mx,0,sizeof(mx)); // 离散化 for(int i=0;i<n;++i) { mp[positions[i][0]] = 1; mp[positions[i][0]+positions[i][1]-1] = 1; } int cnt = 1; for(auto& t:mp) t.second = cnt++; int l; int r; for(int i=0;i<n;++i) { l = mp[positions[i][0]]; r = mp[positions[i][0]+positions[i][1]-1]; int res = query(l,r,1,2009,1); update(l,r,1,2009,1,res+positions[i][1]); ans.push_back(mx[1]); } return ans; } };
例4 张贴海报
法一
比较容易想到的一个思路。每贴一张海报,将其长度所覆盖的那些位置的值改成这张海报的索引。最后遍历整个数组,统计不重复的海报编号数目即可。时间复杂度较高。
#include<bits/stdc++.h> using namespace std; const int max_n = 1e7; int v[max_n]; int main() { memset(v,0,sizeof(v)); int n; cin>>n; int a,b; for(int i=1;i<=n;++i) { cin>>a>>b; for(int j=a;j<=b;++j) v[j] = i; } set<int>se; for(int i=1;i<max_n;++i) if(v[i]) se.insert(v[i]); cout<<se.size()<<endl; return 0; }
法二 线段树
通过线段树来维护区间。最后统计不重复的海报编号数目。
#include<bits/stdc++.h> using namespace std; const int max_n = 1e5; int seg[4*max_n]; // 懒惰标记,用于延迟传播 int add[4*max_n]; // 建树 void create(int l,int r,int rt) { if(l==r) return; create(l,(l+r)>>1,rt<<1); create(((l+r)>>1)+1,r,rt<<1|1); } // 下推函数 void pushDown(int rt) { if(add[rt]) { // 将标记下推到左右子区间 add[rt<<1] = add[rt]; add[rt<<1|1] = add[rt]; // 清除本节点的标记 add[rt] = 0; } } void update(int L,int R,int l,int r,int rt,int val) { if(L<=l&&R>=r) { add[rt] = val; //子区间仍需要根据add值进行调整 return; } //if(l==r) // 叶子节点更新 //{ // add[rt] = val; // return; //} // 下推标记 pushDown(rt); // 继续寻找区间 int mid = (l+r)>>1; if(L<=mid) update(L,R,l,mid,rt<<1,val); if(R>mid) update(L,R,mid+1,r,rt<<1|1,val); } // 区间查询不重复的海报数目 void query(int l,int r,int rt,set<int>& se) { if(add[rt]) { // if(l==r) // { se.insert(add[rt]); return; // } // else pushDown(rt); } // 到叶节点了 返回 if(l==r) return; query(l,(l+r)>>1,rt<<1,se); query(((l+r)>>1)+1,r,rt<<1|1,se); } int main() { int n; cin>>n; int a,b; vector<pair<int,int>>data; // 辅助数组 vector<int>xs; for(int i=1;i<=n;++i) { cin>>a>>b; data.push_back({a,b}); xs.push_back(a); xs.push_back(b); } // 排序 sort(xs.begin(),xs.end()); // 去重 xs.erase(unique(xs.begin(),xs.end()),xs.end()); create(1,xs.size(),1); set<int>se; for(int i=1;i<=n;++i) { // 离散化 节约线段树需要开辟的空间 int l = lower_bound(xs.begin(),xs.end(),data[i-1].first)-xs.begin()+1; int r = lower_bound(xs.begin(),xs.end(),data[i-1].second)-xs.begin()+1; // 更新 update(l,r,1,xs.size(),1,i); } // 统计数目 query(1,xs.size(),1,se); cout<<se.size()<<endl; return 0; }
例5
思路
直接统计每个元素右侧比它小的元素个数,O(n^2),超时。此题可有多种解法。
线段树解法:离散化操作后,求得整个序列的最小值min和最大值max,线段树的整个区间为[min,max].
从数组的最后一位开始遍历,更新线段树。若当前元素nums[i],统计nums[i]右侧比它小的元素个数,即统计落在区间[min,nums[i]-1]这个区间的数的个数。故每遇见一个数,就将包含该数的区间的计数值+1,这样在查询时候,统计落在某个区间的数的个数就转化为了区间和问题。
按照此思路,题目可变形为求每个元素左/右侧比它大/小的元素个数。
const int max_n = 1e6; //int seg[4*max_n]; int sum[4*max_n]; class Solution { public: void update(int l,int r,int rt,int val) { if(l==r) { sum[rt] += 1; return; } int m = (l+r)>>1; if(val<=m) update(l,m,rt<<1,val); else update(m+1,r,rt<<1|1,val); sum[rt] = sum[rt<<1]+sum[rt<<1|1]; } int query(int L,int R,int l,int r, int rt) { if(L<=l&&R>=r) { return sum[rt]; } int m = (l+r)>>1; int ans = 0; if(L<=m)ans+=query(L,R,l,m,rt<<1); if(R>m) ans+=query(L,R,m+1,r,rt<<1|1); return ans; } vector<int> countSmaller(vector<int>& nums) { vector<int>ans; if(nums.empty()) return ans; ans.resize(nums.size(),0); memset(sum,0,sizeof(sum)); int MAX = INT_MIN; int MIN = INT_MAX; vector<int>xs; for(auto t:nums) xs.push_back(t); sort(nums.begin(),nums.end()); // 离散化 for(int i=0;i<xs.size();++i) xs[i] = (lower_bound(nums.begin(),nums.end(),xs[i])-nums.begin()+1); //for(auto a:xs) cout<<" "<<a; // 统计整个序列的最值 for(auto t:xs){ MAX = max(MAX,t); MIN = min(MIN,t); } for(int i=nums.size()-1;i>=0;--i) { update(MIN,MAX,1,xs[i]); if(xs[i]==MIN) continue; int cnt = query(MIN,xs[i]-1,MIN,MAX,1); ans[i] = cnt; } return ans; } };