引言
树套树,顾名思义,就是要将两种或多种树形数据结构结合起来,解决一些单独无法解决的问题。
如果说要解决区间上的问题,如最大值,区间修改等,肯定会想到线段树。
但是线段树不能查询第k大,不能查询一个数在区间的排名,自然也不能查询前驱和后继。
平衡树可以解决查询排名、前驱、后继等问题,但其不能限定区间。
文艺平衡树中有操作可以把区间锁定在一个结点的子树,问题是只能通过翻转左右子树,来实现区间翻转。
既然单独无法解决这个问题,那就将两种树形数据结构结合起来。
原理
很多人都对树套树望而生畏,包括我。。。
以前只知道通过树套树可以解决的问题,但没有敲过
经常听到队友说几个线段树,再用一个主席树维护什么什么的,但其实原理不难,只要懂树套树中的这两种树。
举个例子:
如图这是个线段树
假设这个序列是: 5 2 3 4 5 7 8 9 3 1 (随便写的)
现在我要查2-7区间中第5个数即5在这个区间排第几小,
[3-7]区间即:[3],[4-5],[6-7]
第几小即计算有多少个比它小,然后加一
[3]:1个
[4-5]:1个
[6-7]:0个
所以他是第3小的。
将每个子区间得到的答案求和利用的是线段树
而中间每个区间查询有多少比它小利用的是平衡树Splay
即线段树的每个结点建立一个Splay
有人会怀疑空间复杂度不够,如果把Splay封装,每个Splay都是\(N\)的大小必然不够,我们不需要事先开辟那么多空间来建Splay
代码中是:(开局一个root,然后记录每个线段树结点的root就行了)
void build(int p,int l,int r){
t[p].l = l,t[p].r = r;
//线段树每个结点建立一个splay
sp.ins(t[p].rt,-inf);
sp.ins(t[p].rt,inf);
for(int i = l;i <= r;++i){
sp.ins(t[p].rt,arr[i]);
}
if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
int mid = (l+r) >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid + 1,r);
pushUp(p);
}
这里有个问题,root会发生变化,所以线段树结点中定义的root并不是一成不变的,这需要用到引用,即传地址。
还有就是要插入两个无穷大结点,来解决不存在的情况。
应用-模板题
- 查询k在区间内的排名
- 查询区间内排名为k的值
- 修改某一位值上的数值
- 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
- 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)
先复制上封装好的Splay
struct Splay {
int get(int x) {return s[s[x].fa].ch[1] == x;}
void Clear(int x) {
s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
}
void maintain(int x){
s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
}
void Rorate(int x){
int y = s[x].fa, z = s[y].fa, chk = get(x);
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
s[y].fa = x;
s[x].ch[chk ^ 1] = y;
s[x].fa =z;
if(z) s[z].ch[s[z].ch[1] == y] = x;
maintain(y);
maintain(x);
}
void splay(int &rt,int x,int y){
for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
}
if(y==0) rt = x;
}
void ins(int &root ,int val){
if(!root) {
root = ++tot;
s[root].val = val;
s[root].cnt++;
maintain(root);
return ;
}
int f = 0, x = root;
while(true){
if(s[x].val == val){
s[x].cnt ++;
maintain(x);
maintain(f);
splay(root,x,0);
return;
}
f = x;
x = s[x].ch[s[x].val < val];
if(!x) {
s[++tot].val = val;
s[tot].cnt = 1;
s[tot].fa = f;
s[f].ch[s[f].val < val] = tot;
maintain(tot);
maintain(f);
splay(root,tot,0);
return ;
}
}
}
inline int Find(int &rt,int k) {
int res = 0,now = rt;
while(true) {
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否则加上右子树的个数
res += s[s[now].ch[0]].sz;
//中序遍历,如果找到这个节点返回res+1
if(k == s[now].val) {
splay(rt,now,0);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
int getPre(int rt){
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
int getNxt(int rt){
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
inline void del(int &rt,int k){
Find(rt,k);//先让该点成为根节点
if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一个点
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//没有左儿子,让右儿子成为根节点
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//没有右儿子,让左儿子成为根节点
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右儿子,让前驱成为根节点
int x = getPre(rt) , now = rt;
splay(rt,x,0);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
}sp;
- 问题1之前提到了,就是在Splay中插入这个结点,然后返回这个结点的左儿子的Size就行,记得减去无穷大的那个点。
int query_order(int p,int l,int r,int val){
//查询顺序,就是查有多少个比他小
if(l <= t[p].l && t[p].r <= r){
sp.ins(t[p].rt,val);
int res = s[s[t[p].rt].ch[0]].sz-1;
sp.del(t[p].rt,val);
return res;
}
int mid = (t[p].l + t[p].r) >> 1,res = 0;
if(l <= mid) res += query_order(p << 1,l,r,val);
if(mid < r) res += query_order(p << 1|1,l,r,val);
return res;
}
- 问题2求排名k的值,这需要用到二分,二分check函数就是问题1的
query_order
,在区间权值范围内二分,权值越大排名越大,就是在单调递增区间中查询小于k的数的最大值(因为有一个无穷小结点,所以不能小于等于)。二分模板也很明显:
int query_number(int L,int R,int val){
int l = 1,r = getMax(1,L,R) ,mid,tmp;
while(l < r){
mid = (l + r + 1)>>1;
tmp = query_order(1,L,R,mid);
if(tmp < val){
l = mid;
}else{
r = mid - 1;
}
}
return l;
}
- 问题3是修改,这个不难,这个点所在的所有线段树结点都要删除该点在Splay树上的结点,然后加入新值。
void modify(int p,int pos,int val){
sp.del(t[p].rt,arr[pos]);
sp.ins(t[p].rt,val);
if(t[p].l == t[p].r){
t[p].mx = val;
t[p].mn = val;
arr[pos] = val;
return;
}
int mid = (t[p].l + t[p].r) >> 1;
if(pos <= mid) modify(p << 1,pos,val);
if(pos > mid) modify(p << 1 | 1,pos,val);
pushUp(p);
}
- 问题4,查询前驱,即查询每个线段树区间最大的比该数小的数,最后取个最大值。5同理
int query_Pre(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getPre(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = -inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = max(res,query_Pre(p << 1,l,r,val));
if(r > mid) res = max(res,query_Pre(p << 1|1,l,r,val));
return res;
}
int query_Nxt(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getNxt(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = min(res,query_Nxt(p << 1,l,r,val));
if(r > mid) res = min(res,query_Nxt(p << 1|1,l,r,val));
return res;
}
- 中途为了优化二分(也没什么用),还加了线段树查询最大值和最小值的
int getMax(int p,int l,int r){
if(l <= t[p].l && t[p].r <=r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1,res = -inf;
if(l <= mid) res = max(res,getMax(p << 1,l,r));
if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
return res;
}
int getMin(int p,int l,int r){
if(l <= t[p].l && t[p].r <= r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
if(l <= mid) res = min(res,getMin(p << 1,l,r));
if(mid < r) res = min(res,getMin(p << 1|1,l,r));
return res;
}
完整代码
#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
#define ll long long
const int N = 1e7+7;
const int inf = 2147483647;
int tot;//节点个数
struct node {
int fa;//父亲节点
int ch[2];//子节点
int val;//权值
int sz;//子树大小
int cnt;
}s[N];
struct Tree{
int rt,l,r,mx,mn;
}t[N];
int arr[N];
struct Splay {
int get(int x) {return s[s[x].fa].ch[1] == x;}
void Clear(int x) {
s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
}
void maintain(int x){
s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
}
void Rorate(int x){
int y = s[x].fa, z = s[y].fa, chk = get(x);
s[y].ch[chk] = s[x].ch[chk ^ 1];
s[s[x].ch[chk ^ 1]].fa = y;
s[y].fa = x;
s[x].ch[chk ^ 1] = y;
s[x].fa =z;
if(z) s[z].ch[s[z].ch[1] == y] = x;
maintain(y);
maintain(x);
}
void splay(int &rt,int x,int y){
for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
}
if(y==0) rt = x;
}
void ins(int &root ,int val){
if(!root) {
root = ++tot;
s[root].val = val;
s[root].cnt++;
maintain(root);
return ;
}
int f = 0, x = root;
while(true){
if(s[x].val == val){
s[x].cnt ++;
maintain(x);
maintain(f);
splay(root,x,0);
return;
}
f = x;
x = s[x].ch[s[x].val < val];
if(!x) {
s[++tot].val = val;
s[tot].cnt = 1;
s[tot].fa = f;
s[f].ch[s[f].val < val] = tot;
maintain(tot);
maintain(f);
splay(root,tot,0);
return ;
}
}
}
inline int Find(int &rt,int k) {
int res = 0,now = rt;
while(true) {
if(k<s[now].val) {
now = s[now].ch[0];
}else {
//否则加上右子树的个数
res += s[s[now].ch[0]].sz;
//中序遍历,如果找到这个节点返回res+1
if(k == s[now].val) {
splay(rt,now,0);
return res + 1;
}
res += s[now].cnt;
now = s[now].ch[1];
}
}
}
int getPre(int rt){
int now = s[rt].ch[0];
while (s[now].ch[1]) now = s[now].ch[1];
return now;
}
int getNxt(int rt){
int now = s[rt].ch[1];
while (s[now].ch[0]) now = s[now].ch[0];
return now;
}
inline void del(int &rt,int k){
Find(rt,k);//先让该点成为根节点
if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
s[rt].cnt--;
maintain(rt);
return;
}
//如果只有一个点
if(!s[rt].ch[0] && !s[rt].ch[1]){
Clear(rt);
rt = 0;
return;
}
//没有左儿子,让右儿子成为根节点
if(!s[rt].ch[0]){
int tmp = rt;
rt = s[rt].ch[1];
s[rt].fa=0;
Clear(tmp);
return;
}
//没有右儿子,让左儿子成为根节点
if(!s[rt].ch[1]){
int tmp = rt;
rt = s[rt].ch[0];
s[rt].fa = 0;
Clear(tmp);
return;
}
//有左右儿子,让前驱成为根节点
int x = getPre(rt) , now = rt;
splay(rt,x,0);
s[s[now].ch[1]].fa = x;
s[x].ch[1] = s[now].ch[1];
Clear(now);
maintain(rt);
}
}sp;
void pushUp(int x){
t[x].mx = max(t[x<<1].mx,t[x<<1|1].mx);
t[x].mn = min(t[x<<1].mn,t[x<<1|1].mn);
}
void build(int p,int l,int r){
t[p].l = l,t[p].r = r;
//线段树每个结点建立一个splay
sp.ins(t[p].rt,-inf);
sp.ins(t[p].rt,inf);
for(int i = l;i <= r;++i){
sp.ins(t[p].rt,arr[i]);
}
if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
int mid = (l+r) >> 1;
build(p<<1,l,mid);
build(p<<1|1,mid + 1,r);
pushUp(p);
}
int getMax(int p,int l,int r){
if(l <= t[p].l && t[p].r <=r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1,res = -inf;
if(l <= mid) res = max(res,getMax(p << 1,l,r));
if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
return res;
}
int getMin(int p,int l,int r){
if(l <= t[p].l && t[p].r <= r) return t[p].mx;
int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
if(l <= mid) res = min(res,getMin(p << 1,l,r));
if(mid < r) res = min(res,getMin(p << 1|1,l,r));
return res;
}
int query_order(int p,int l,int r,int val){
//查询顺序,就是查有多少个比他小
if(l <= t[p].l && t[p].r <= r){
sp.ins(t[p].rt,val);
int res = s[s[t[p].rt].ch[0]].sz-1;
sp.del(t[p].rt,val);
return res;
}
int mid = (t[p].l + t[p].r) >> 1,res = 0;
if(l <= mid) res += query_order(p << 1,l,r,val);
if(mid < r) res += query_order(p << 1|1,l,r,val);
return res;
}
void modify(int p,int pos,int val){
sp.del(t[p].rt,arr[pos]);
sp.ins(t[p].rt,val);
if(t[p].l == t[p].r){
t[p].mx = val;
t[p].mn = val;
arr[pos] = val;
return;
}
int mid = (t[p].l + t[p].r) >> 1;
if(pos <= mid) modify(p << 1,pos,val);
if(pos > mid) modify(p << 1 | 1,pos,val);
pushUp(p);
}
int query_Pre(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getPre(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = -inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = max(res,query_Pre(p << 1,l,r,val));
if(r > mid) res = max(res,query_Pre(p << 1|1,l,r,val));
return res;
}
int query_Nxt(int p,int l,int r,int val){
if(l <= t[p].l && r >= t[p].r){
sp.ins(t[p].rt,val);
int res = s[sp.getNxt(t[p].rt)].val;
sp.del(t[p].rt,val);
return res;
}
int res = inf,mid = (t[p].l + t[p].r) >> 1;
if(l <= mid) res = min(res,query_Nxt(p << 1,l,r,val));
if(r > mid) res = min(res,query_Nxt(p << 1|1,l,r,val));
return res;
}
int query_number(int L,int R,int val){
int l = 1,r = getMax(1,L,R) ,mid,tmp;
while(l < r){
mid = (l + r + 1)>>1;
tmp = query_order(1,L,R,mid);
if(tmp < val){
l = mid;
}else{
r = mid - 1;
}
}
return l;
}
int main(){
int n,q,op,l,r,pos;
scanf("%d%d",&n,&q);
for(int i=1;i<=n;++i) scanf("%d",&arr[i]);
build(1,1,n);
while(q--){
scanf("%d",&op);
if(op == 1){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_order(1,l,r,pos)+1);
}else if(op == 2){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_number(l,r,pos));
}else if(op == 3){
scanf("%d%d",&l,&pos);
modify(1,l,pos);
}else if(op == 4){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_Pre(1,l,r,pos));
}else if(op == 5){
scanf("%d%d%d",&l,&r,&pos);
printf("%d\n",query_Nxt(1,l,r,pos));
}
}
return 0;
}
代码不加O2优化会超时,如果要优化的话,可以加个输入输出挂。
后记
博客两周年快乐。
这是第一篇博客https://www.cnblogs.com/smallocean/p/8525932.html:2018.3.7
发现自己留下的东西都可以当作时间胶囊,等未来某天翻看的时候,仿佛能看到那个时候的自己。