树套树是指在一个树形数据结构上,每个节点都不再是一个节点,而是另一种数据结构。
线段树可用于:点更新、区间更新、区间查询
平衡树可用于:第 k 小、排名第 k 的数、前驱、后继
将两者结合起来,就是线段树套平衡树
用线段树维护区间,再用平衡树维护对区间中的动态修改。
本题包括 5 种操作:
区间排名、区间第 k 小、点更新、区间前驱、区间后继
A: 区间排名
在线段树中查询 k 在 [ql, qr] 区间的排名,统计区间中小于 k 的元素个数,然后 +1 得到 k 在区间内的排名
[ql, qr]区间的话
依然按照线段树的分类方法,左区间递归、右区间递归
B:区间第 k 小
区间内的元素是无大小顺序的,所以不可以按照区间查找排名。只能按照值的大小来二分搜索。
C:点更新
除线段树上的操作外,还需要更新每个节点对应的平衡树,最后修改 p[pos] = k
D:查询 k 在 [ql, qr] 区间的前驱
若查询区间无交集,则是 -inf
若查询区间覆盖当前节点区间,则在当前节点的平衡树中查询 k 的前驱
否则在左右子树中搜索,求解 MAX
E:求后继,同前驱求法
#include <bits/stdc++.h>
using namespace std;
#define ls tr[x].ch[0]
#define rs tr[x].ch[1]
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid + 1, r
const int maxn = 5e4 + 10;
const int inf = 2147483647;
int n, m, p[maxn], tot;
struct Node{
int val, key, cnt, sz, ch[2];
}tr[maxn * 30];
struct Treap{
int root;
void Update(int x){
tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt;
}
void Rotate(int &x, bool c){
int son = tr[x].ch[c];
tr[x].ch[c] = tr[son].ch[c ^ 1];
tr[son].ch[c ^ 1] = x;
Update(x);
Update(x=son);
}
void Insert(int &x, int val){
if (!x){
x = ++tot;
tr[x].cnt = tr[x].sz = 1;
tr[x].key = rand();
tr[x].val = val;
return;
}
tr[x].sz++;
if (tr[x].val == val){
tr[x].cnt ++;
return;
}
bool c = val > tr[x].val;
Insert(tr[x].ch[c], val);
if (tr[x].key > tr[tr[x].ch[c]].key)
Rotate(x, c);
}
void Delete(int &x, int val){
if (!x) return;
if (tr[x].val == val){
if (tr[x].cnt > 1){
tr[x].cnt--;
tr[x].sz--;
return;
}
bool c = tr[ls].key > tr[rs].key;
if (ls == 0 || rs == 0)
x = ls + rs;
else
Rotate(x, c), Delete(x, val);
}
else{
tr[x].sz--;
Delete(tr[x].ch[tr[x].val < val], val);
}
}
int Rank(int x, int val){
if (!x) return 0;
if (tr[x].val == val) return tr[ls].sz;
else if (tr[x].val > val) return Rank(ls, val);
else return tr[ls].sz + tr[x].cnt + Rank(rs, val);
}
int Kth(int x, int k){
while(1){
if (k <= tr[ls].sz) x = ls;
else if (k > tr[ls].sz + tr[x].cnt) k-=tr[ls].sz + tr[x].cnt, x = rs;
else return tr[x].val;
}
}
int Pre(int x, int val){
if (!x) return -inf;
else if (tr[x].val >= val) return Pre(ls, val);
else return max(tr[x].val, Pre(rs, val));
}
int Nxt(int x, int val){
if (!x) return inf;
else if (tr[x].val <= val) return Nxt(rs, val);
else return min(tr[x].val, Nxt(ls, val));
}
}a[maxn << 2];
void Build(int rt, int l, int r){
a[rt].root = 0;
for(int i = l; i <= r; i++)
a[rt].Insert(a[rt].root, p[i]);
if (l == r) return;
int mid = (l + r) >> 1;
Build(lson);
Build(rson);
}
int QueryRank(int rt, int l, int r, int ql, int qr, int k){
if (l > qr || r < ql) return 0;
if (ql <= l && r <= qr) return a[rt].Rank(a[rt].root, k);
int ans = 0, mid = (l + r) >> 1;
ans += QueryRank(lson, ql, qr, k);
ans += QueryRank(rson, ql, qr, k);
return ans;
}
int QueryVal(int ql, int qr, int k){
int l = 0, r = 1e8, mid, ans = -1;
while(l <= r){
mid = (l + r) >> 1;
if (QueryRank(1, 1, n, ql, qr, mid) + 1 <= k) ans = mid, l = mid + 1;
else r = mid - 1;
}
return ans;
}
void Modify(int rt, int l, int r, int pos, int k){
if (pos < l || r < pos) return;
a[rt].Delete(a[rt].root, p[pos]);
a[rt].Insert(a[rt].root, k);
if (l == r) return;
int mid = (l + r) >> 1;
Modify(lson, pos, k);
Modify(rson, pos, k);
}
int QueryPre(int rt, int l, int r, int ql, int qr, int k){
if (l > qr || r < ql) return -inf;
if (ql <= l && r <= qr) return a[rt].Pre(a[rt].root, k);
int mid = (l + r) >> 1;
int ans = QueryPre(lson, ql, qr, k);
ans = max(ans, QueryPre(rson, ql, qr, k));
return ans;
}
int QueryNxt(int rt, int l, int r, int ql, int qr, int k){
if (l > qr || r < ql) return inf;
if (ql <= l && r <= qr) return a[rt].Nxt(a[rt].root, k);
int mid = (l + r) >> 1;
int ans = QueryNxt(lson, ql, qr, k);
ans = min(ans, QueryNxt(rson, ql, qr, k));
return ans;
}
int main(){
//freopen("input.txt", "r", stdin);
scanf("%d%d",&n, &m);
for(int i=1;i <= n;i++)
scanf("%d", &p[i]);
tot = 0;
Build(1, 1, n);
while(m--){
int opt, l, r, k, pos;
scanf("%d", &opt);
if(opt == 1){
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", QueryRank(1, 1, n, l, r, k) + 1);
}
else if(opt == 2){
scanf("%d%d%d",&l, &r, &k);
printf("%d\n", QueryVal(l, r, k));
}
else if(opt == 3){
scanf("%d%d", &pos, &k);
Modify(1 ,1, n, pos, k);
p[pos] = k;
}
else if(opt == 4){
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", QueryPre(1, 1, n, l, r, k));
}
else{
scanf("%d%d%d", &l, &r, &k);
printf("%d\n", QueryNxt(1, 1, n, l, r, k));
}
}
return 0;
}



京公网安备 11010502036488号