引言

树套树,顾名思义,就是要将两种或多种树形数据结构结合起来,解决一些单独无法解决的问题。

如果说要解决区间上的问题,如最大值,区间修改等,肯定会想到线段树

但是线段树不能查询第k大,不能查询一个数在区间的排名,自然也不能查询前驱和后继。

平衡树可以解决查询排名、前驱、后继等问题,但其不能限定区间。

文艺平衡树中有操作可以把区间锁定在一个结点的子树,问题是只能通过翻转左右子树,来实现区间翻转。

既然单独无法解决这个问题,那就将两种树形数据结构结合起来。

原理

很多人都对树套树望而生畏,包括我。。。

以前只知道通过树套树可以解决的问题,但没有敲过

经常听到队友说几个线段树,再用一个主席树维护什么什么的,但其实原理不难,只要懂树套树中的这两种树。

举个例子:

如图这是个线段树

假设这个序列是: 5 2 3 4 5 7 8 9 3 1 (随便写的)

现在我要查2-7区间中第5个数即5在这个区间排第几小,

[3-7]区间即:[3],[4-5],[6-7]

第几小即计算有多少个比它小,然后加一

[3]:1个

[4-5]:1个

[6-7]:0个

所以他是第3小的。

将每个子区间得到的答案求和利用的是线段树

而中间每个区间查询有多少比它小利用的是平衡树Splay

线段树的每个结点建立一个Splay

有人会怀疑空间复杂度不够,如果把Splay封装,每个Splay都是\(N\)的大小必然不够,我们不需要事先开辟那么多空间来建Splay

代码中是:(开局一个root,然后记录每个线段树结点的root就行了)

void build(int p,int l,int r){
    t[p].l = l,t[p].r = r;
    //线段树每个结点建立一个splay
    sp.ins(t[p].rt,-inf);
    sp.ins(t[p].rt,inf);
    for(int i = l;i <= r;++i){
        sp.ins(t[p].rt,arr[i]);
    }
    if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
    int mid = (l+r) >> 1;
    build(p<<1,l,mid);
    build(p<<1|1,mid + 1,r);
    pushUp(p);
}

这里有个问题,root会发生变化,所以线段树结点中定义的root并不是一成不变的,这需要用到引用,即传地址

还有就是要插入两个无穷大结点,来解决不存在的情况。

应用-模板题

  1. 查询k在区间内的排名
  2. 查询区间内排名为k的值
  3. 修改某一位值上的数值
  4. 查询k在区间内的前驱(前驱定义为严格小于x,且最大的数,若不存在输出-2147483647)
  5. 查询k在区间内的后继(后继定义为严格大于x,且最小的数,若不存在输出2147483647)

先复制上封装好的Splay

struct Splay {
     int get(int x) {return s[s[x].fa].ch[1] == x;}

     void Clear(int x) {
         s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
     }

     void maintain(int x){
         s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
     }

     void Rorate(int x){
         int y = s[x].fa, z = s[y].fa, chk = get(x);

         s[y].ch[chk] = s[x].ch[chk ^ 1];
         s[s[x].ch[chk ^ 1]].fa = y;

         s[y].fa = x;
         s[x].ch[chk ^ 1] = y;

         s[x].fa =z;
         if(z) s[z].ch[s[z].ch[1] == y] = x;

         maintain(y);
         maintain(x);
     }

     void splay(int &rt,int x,int y){
         for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
             if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
         }
         if(y==0) rt = x;
     }

     void ins(int &root ,int val){
         if(!root) {
             root = ++tot;
             s[root].val = val;
             s[root].cnt++;
             maintain(root);
             return ;
         }

         int f = 0, x = root;
         while(true){
             if(s[x].val == val){
                 s[x].cnt ++;
                 maintain(x);
                 maintain(f);
                 splay(root,x,0);
                 return;
             }

             f = x;
             x = s[x].ch[s[x].val < val];
             if(!x) {
                 s[++tot].val = val;
                 s[tot].cnt = 1;
                 s[tot].fa = f;
                 s[f].ch[s[f].val < val] = tot;
                 maintain(tot);
                 maintain(f);
                 splay(root,tot,0);
                 return ;
             }
         }
     }

    inline int Find(int &rt,int k) {
        int res = 0,now = rt;
        while(true) {
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否则加上右子树的个数
                res += s[s[now].ch[0]].sz;
                //中序遍历,如果找到这个节点返回res+1
                if(k == s[now].val) {
                    splay(rt,now,0);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

     int getPre(int rt){
         int now = s[rt].ch[0];
         while (s[now].ch[1]) now = s[now].ch[1];
         return now;
     }

     int getNxt(int rt){
         int now = s[rt].ch[1];
         while (s[now].ch[0]) now = s[now].ch[0];
         return now;
     }
     
    inline void del(int &rt,int k){
       Find(rt,k);//先让该点成为根节点
        if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一个点
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //没有左儿子,让右儿子成为根节点
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //没有右儿子,让左儿子成为根节点
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右儿子,让前驱成为根节点
        int x = getPre(rt) , now = rt;
        splay(rt,x,0);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }
}sp;
  • 问题1之前提到了,就是在Splay中插入这个结点,然后返回这个结点的左儿子的Size就行,记得减去无穷大的那个点。
int query_order(int p,int l,int r,int val){
    //查询顺序,就是查有多少个比他小
    if(l <= t[p].l && t[p].r <= r){
       sp.ins(t[p].rt,val);
       int res = s[s[t[p].rt].ch[0]].sz-1;
       sp.del(t[p].rt,val);
       return res;
    }
    int mid = (t[p].l + t[p].r) >> 1,res = 0;
    if(l <= mid) res += query_order(p << 1,l,r,val);
    if(mid < r) res += query_order(p << 1|1,l,r,val);
    return res;
}
  • 问题2求排名k的值,这需要用到二分,二分check函数就是问题1的query_order,在区间权值范围内二分,权值越大排名越大,就是在单调递增区间中查询小于k的数的最大值(因为有一个无穷小结点,所以不能小于等于)。二分模板也很明显:
int query_number(int L,int R,int val){
    int l = 1,r = getMax(1,L,R) ,mid,tmp;
    while(l < r){
        mid = (l + r + 1)>>1;
        tmp = query_order(1,L,R,mid);
        if(tmp < val){
            l = mid;
        }else{
            r = mid - 1;
        }
    }
    return l;
}
  • 问题3是修改,这个不难,这个点所在的所有线段树结点都要删除该点在Splay树上的结点,然后加入新值。
void modify(int p,int pos,int val){
     sp.del(t[p].rt,arr[pos]);
     sp.ins(t[p].rt,val);
     if(t[p].l == t[p].r){
         t[p].mx = val;
         t[p].mn = val;
         arr[pos] = val;
         return;
     }
     int mid = (t[p].l + t[p].r) >> 1;
     if(pos <= mid) modify(p << 1,pos,val);
     if(pos > mid) modify(p << 1 | 1,pos,val);
     pushUp(p);
}
  • 问题4,查询前驱,即查询每个线段树区间最大的比该数小的数,最后取个最大值。5同理
int query_Pre(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getPre(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = -inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
    if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
    return res;
}

int query_Nxt(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getNxt(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
    if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
    return res;
}
  • 中途为了优化二分(也没什么用),还加了线段树查询最大值和最小值的
int getMax(int p,int l,int r){
    if(l <= t[p].l && t[p].r <=r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1,res = -inf;
    if(l <= mid) res = max(res,getMax(p << 1,l,r));
    if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
    return res;
}

int getMin(int p,int l,int r){
    if(l <= t[p].l && t[p].r <= r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
    if(l <= mid) res = min(res,getMin(p << 1,l,r));
    if(mid < r) res = min(res,getMin(p << 1|1,l,r));
    return res;
}

完整代码

#pragma GCC optimize(2)
#pragma GCC optimize(3,"Ofast","inline")
#include<bits/stdc++.h>
using namespace std;
#define ll long long

const int N = 1e7+7;
const int inf = 2147483647;
int tot;//节点个数
struct node {
    int fa;//父亲节点
    int ch[2];//子节点
    int val;//权值
    int sz;//子树大小
    int cnt;
}s[N];
struct Tree{
    int rt,l,r,mx,mn;
}t[N];
int arr[N];
struct Splay {
     int get(int x) {return s[s[x].fa].ch[1] == x;}

     void Clear(int x) {
         s[x].fa = s[x].ch[0] = s[x].ch[1] = s[x].sz = s[x].val =0;
     }

     void maintain(int x){
         s[x].sz = s[s[x].ch[0]].sz + s[s[x].ch[1]].sz + s[x].cnt;
     }

     void Rorate(int x){
         int y = s[x].fa, z = s[y].fa, chk = get(x);

         s[y].ch[chk] = s[x].ch[chk ^ 1];
         s[s[x].ch[chk ^ 1]].fa = y;

         s[y].fa = x;
         s[x].ch[chk ^ 1] = y;

         s[x].fa =z;
         if(z) s[z].ch[s[z].ch[1] == y] = x;

         maintain(y);
         maintain(x);
     }

     void splay(int &rt,int x,int y){
         for(int f = s[x].fa;f != y;Rorate(x),f=s[x].fa){
             if(s[f].fa != y) Rorate(get(x) == get(f) ? f : x);
         }
         if(y==0) rt = x;
     }

     void ins(int &root ,int val){
         if(!root) {
             root = ++tot;
             s[root].val = val;
             s[root].cnt++;
             maintain(root);
             return ;
         }

         int f = 0, x = root;
         while(true){
             if(s[x].val == val){
                 s[x].cnt ++;
                 maintain(x);
                 maintain(f);
                 splay(root,x,0);
                 return;
             }

             f = x;
             x = s[x].ch[s[x].val < val];
             if(!x) {
                 s[++tot].val = val;
                 s[tot].cnt = 1;
                 s[tot].fa = f;
                 s[f].ch[s[f].val < val] = tot;
                 maintain(tot);
                 maintain(f);
                 splay(root,tot,0);
                 return ;
             }
         }
     }

    inline int Find(int &rt,int k) {
        int res = 0,now = rt;
        while(true) {
            if(k<s[now].val) {
                now = s[now].ch[0];
            }else {
                //否则加上右子树的个数
                res += s[s[now].ch[0]].sz;
                //中序遍历,如果找到这个节点返回res+1
                if(k == s[now].val) {
                    splay(rt,now,0);
                    return res + 1;
                }
                res += s[now].cnt;
                now = s[now].ch[1];
            }
        }
    }

     int getPre(int rt){
         int now = s[rt].ch[0];
         while (s[now].ch[1]) now = s[now].ch[1];
         return now;
     }

     int getNxt(int rt){
         int now = s[rt].ch[1];
         while (s[now].ch[0]) now = s[now].ch[0];
         return now;
     }

    inline void del(int &rt,int k){
       Find(rt,k);//先让该点成为根节点
        if(s[rt].cnt > 1) {//如果大于1,不需要删除节点
            s[rt].cnt--;
            maintain(rt);
            return;
        }
        //如果只有一个点
        if(!s[rt].ch[0] && !s[rt].ch[1]){
            Clear(rt);
            rt = 0;
            return;
        }
        //没有左儿子,让右儿子成为根节点
        if(!s[rt].ch[0]){
            int tmp = rt;
            rt = s[rt].ch[1];
            s[rt].fa=0;
            Clear(tmp);
            return;
        }
        //没有右儿子,让左儿子成为根节点
        if(!s[rt].ch[1]){
            int tmp = rt;
            rt = s[rt].ch[0];
            s[rt].fa = 0;
            Clear(tmp);
            return;
        }
        //有左右儿子,让前驱成为根节点
        int x = getPre(rt) , now = rt;
        splay(rt,x,0);
        s[s[now].ch[1]].fa = x;
        s[x].ch[1] = s[now].ch[1];
        Clear(now);
        maintain(rt);
    }
}sp;

void pushUp(int x){
    t[x].mx = max(t[x<<1].mx,t[x<<1|1].mx);
    t[x].mn = min(t[x<<1].mn,t[x<<1|1].mn);
}

void build(int p,int l,int r){
    t[p].l = l,t[p].r = r;
    //线段树每个结点建立一个splay
    sp.ins(t[p].rt,-inf);
    sp.ins(t[p].rt,inf);
    for(int i = l;i <= r;++i){
        sp.ins(t[p].rt,arr[i]);
    }
    if(l == r){ t[p].mx = t[p].mn = arr[l];return; }
    int mid = (l+r) >> 1;
    build(p<<1,l,mid);
    build(p<<1|1,mid + 1,r);
    pushUp(p);
}

int getMax(int p,int l,int r){
    if(l <= t[p].l && t[p].r <=r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1,res = -inf;
    if(l <= mid) res = max(res,getMax(p << 1,l,r));
    if(mid < r) res = max(res,getMax(p << 1 | 1,l,r));
    return res;
}

int getMin(int p,int l,int r){
    if(l <= t[p].l && t[p].r <= r) return t[p].mx;
    int mid = (t[p].l + t[p].r) >> 1 ,res = inf;
    if(l <= mid) res = min(res,getMin(p << 1,l,r));
    if(mid < r) res = min(res,getMin(p << 1|1,l,r));
    return res;
}

int query_order(int p,int l,int r,int val){
    //查询顺序,就是查有多少个比他小
    if(l <= t[p].l && t[p].r <= r){
       sp.ins(t[p].rt,val);
       int res = s[s[t[p].rt].ch[0]].sz-1;
       sp.del(t[p].rt,val);
       return res;
    }
    int mid = (t[p].l + t[p].r) >> 1,res = 0;
    if(l <= mid) res += query_order(p << 1,l,r,val);
    if(mid < r) res += query_order(p << 1|1,l,r,val);
    return res;
}

void modify(int p,int pos,int val){
     sp.del(t[p].rt,arr[pos]);
     sp.ins(t[p].rt,val);
     if(t[p].l == t[p].r){
         t[p].mx = val;
         t[p].mn = val;
         arr[pos] = val;
         return;
     }
     int mid = (t[p].l + t[p].r) >> 1;
     if(pos <= mid) modify(p << 1,pos,val);
     if(pos > mid) modify(p << 1 | 1,pos,val);
     pushUp(p);
}

int query_Pre(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getPre(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = -inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = max(res,query_Pre(p << 1,l,r,val));
    if(r > mid)  res = max(res,query_Pre(p << 1|1,l,r,val));
    return res;
}

int query_Nxt(int p,int l,int r,int val){
    if(l <= t[p].l && r >= t[p].r){
        sp.ins(t[p].rt,val);
        int res = s[sp.getNxt(t[p].rt)].val;
        sp.del(t[p].rt,val);
        return res;
    }
    int res = inf,mid = (t[p].l + t[p].r) >> 1;
    if(l <= mid)  res = min(res,query_Nxt(p << 1,l,r,val));
    if(r > mid)  res = min(res,query_Nxt(p << 1|1,l,r,val));
    return res;
}

int query_number(int L,int R,int val){
    int l = 1,r = getMax(1,L,R) ,mid,tmp;
    while(l < r){
        mid = (l + r + 1)>>1;
        tmp = query_order(1,L,R,mid);
        if(tmp < val){
            l = mid;
        }else{
            r = mid - 1;
        }
    }
    return l;
}

int main(){
    int n,q,op,l,r,pos;
    scanf("%d%d",&n,&q);
    for(int i=1;i<=n;++i) scanf("%d",&arr[i]);
    build(1,1,n);
    while(q--){
        scanf("%d",&op);
        if(op == 1){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_order(1,l,r,pos)+1);
        }else if(op == 2){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_number(l,r,pos));
        }else if(op == 3){
            scanf("%d%d",&l,&pos);
            modify(1,l,pos);
        }else if(op == 4){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_Pre(1,l,r,pos));
        }else if(op == 5){
            scanf("%d%d%d",&l,&r,&pos);
            printf("%d\n",query_Nxt(1,l,r,pos));
        }
    }
    return 0;
}

代码不加O2优化会超时,如果要优化的话,可以加个输入输出挂。

后记

博客两周年快乐。

这是第一篇博客https://www.cnblogs.com/smallocean/p/8525932.html:2018.3.7

发现自己留下的东西都可以当作时间胶囊,等未来某天翻看的时候,仿佛能看到那个时候的自己。