树套树是指在一个树形数据结构上,每个节点都不再是一个节点,而是另一种数据结构。
线段树可用于:点更新、区间更新、区间查询
平衡树可用于:第 k 小、排名第 k 的数、前驱、后继
将两者结合起来,就是线段树套平衡树
用线段树维护区间,再用平衡树维护对区间中的动态修改。
本题包括 5 种操作:
区间排名、区间第 k 小、点更新、区间前驱、区间后继
A: 区间排名
在线段树中查询 k 在 [ql, qr] 区间的排名,统计区间中小于 k 的元素个数,然后 +1 得到 k 在区间内的排名
[ql, qr]区间的话
依然按照线段树的分类方法,左区间递归、右区间递归
B:区间第 k 小
区间内的元素是无大小顺序的,所以不可以按照区间查找排名。只能按照值的大小来二分搜索。
C:点更新
除线段树上的操作外,还需要更新每个节点对应的平衡树,最后修改 p[pos] = k
D:查询 k 在 [ql, qr] 区间的前驱
若查询区间无交集,则是 -inf
若查询区间覆盖当前节点区间,则在当前节点的平衡树中查询 k 的前驱
否则在左右子树中搜索,求解 MAX
E:求后继,同前驱求法
#include <bits/stdc++.h> using namespace std; #define ls tr[x].ch[0] #define rs tr[x].ch[1] #define lson rt<<1, l, mid #define rson rt<<1|1, mid + 1, r const int maxn = 5e4 + 10; const int inf = 2147483647; int n, m, p[maxn], tot; struct Node{ int val, key, cnt, sz, ch[2]; }tr[maxn * 30]; struct Treap{ int root; void Update(int x){ tr[x].sz = tr[ls].sz + tr[rs].sz + tr[x].cnt; } void Rotate(int &x, bool c){ int son = tr[x].ch[c]; tr[x].ch[c] = tr[son].ch[c ^ 1]; tr[son].ch[c ^ 1] = x; Update(x); Update(x=son); } void Insert(int &x, int val){ if (!x){ x = ++tot; tr[x].cnt = tr[x].sz = 1; tr[x].key = rand(); tr[x].val = val; return; } tr[x].sz++; if (tr[x].val == val){ tr[x].cnt ++; return; } bool c = val > tr[x].val; Insert(tr[x].ch[c], val); if (tr[x].key > tr[tr[x].ch[c]].key) Rotate(x, c); } void Delete(int &x, int val){ if (!x) return; if (tr[x].val == val){ if (tr[x].cnt > 1){ tr[x].cnt--; tr[x].sz--; return; } bool c = tr[ls].key > tr[rs].key; if (ls == 0 || rs == 0) x = ls + rs; else Rotate(x, c), Delete(x, val); } else{ tr[x].sz--; Delete(tr[x].ch[tr[x].val < val], val); } } int Rank(int x, int val){ if (!x) return 0; if (tr[x].val == val) return tr[ls].sz; else if (tr[x].val > val) return Rank(ls, val); else return tr[ls].sz + tr[x].cnt + Rank(rs, val); } int Kth(int x, int k){ while(1){ if (k <= tr[ls].sz) x = ls; else if (k > tr[ls].sz + tr[x].cnt) k-=tr[ls].sz + tr[x].cnt, x = rs; else return tr[x].val; } } int Pre(int x, int val){ if (!x) return -inf; else if (tr[x].val >= val) return Pre(ls, val); else return max(tr[x].val, Pre(rs, val)); } int Nxt(int x, int val){ if (!x) return inf; else if (tr[x].val <= val) return Nxt(rs, val); else return min(tr[x].val, Nxt(ls, val)); } }a[maxn << 2]; void Build(int rt, int l, int r){ a[rt].root = 0; for(int i = l; i <= r; i++) a[rt].Insert(a[rt].root, p[i]); if (l == r) return; int mid = (l + r) >> 1; Build(lson); Build(rson); } int QueryRank(int rt, int l, int r, int ql, int qr, int k){ if (l > qr || r < ql) return 0; if (ql <= l && r <= qr) return a[rt].Rank(a[rt].root, k); int ans = 0, mid = (l + r) >> 1; ans += QueryRank(lson, ql, qr, k); ans += QueryRank(rson, ql, qr, k); return ans; } int QueryVal(int ql, int qr, int k){ int l = 0, r = 1e8, mid, ans = -1; while(l <= r){ mid = (l + r) >> 1; if (QueryRank(1, 1, n, ql, qr, mid) + 1 <= k) ans = mid, l = mid + 1; else r = mid - 1; } return ans; } void Modify(int rt, int l, int r, int pos, int k){ if (pos < l || r < pos) return; a[rt].Delete(a[rt].root, p[pos]); a[rt].Insert(a[rt].root, k); if (l == r) return; int mid = (l + r) >> 1; Modify(lson, pos, k); Modify(rson, pos, k); } int QueryPre(int rt, int l, int r, int ql, int qr, int k){ if (l > qr || r < ql) return -inf; if (ql <= l && r <= qr) return a[rt].Pre(a[rt].root, k); int mid = (l + r) >> 1; int ans = QueryPre(lson, ql, qr, k); ans = max(ans, QueryPre(rson, ql, qr, k)); return ans; } int QueryNxt(int rt, int l, int r, int ql, int qr, int k){ if (l > qr || r < ql) return inf; if (ql <= l && r <= qr) return a[rt].Nxt(a[rt].root, k); int mid = (l + r) >> 1; int ans = QueryNxt(lson, ql, qr, k); ans = min(ans, QueryNxt(rson, ql, qr, k)); return ans; } int main(){ //freopen("input.txt", "r", stdin); scanf("%d%d",&n, &m); for(int i=1;i <= n;i++) scanf("%d", &p[i]); tot = 0; Build(1, 1, n); while(m--){ int opt, l, r, k, pos; scanf("%d", &opt); if(opt == 1){ scanf("%d%d%d", &l, &r, &k); printf("%d\n", QueryRank(1, 1, n, l, r, k) + 1); } else if(opt == 2){ scanf("%d%d%d",&l, &r, &k); printf("%d\n", QueryVal(l, r, k)); } else if(opt == 3){ scanf("%d%d", &pos, &k); Modify(1 ,1, n, pos, k); p[pos] = k; } else if(opt == 4){ scanf("%d%d%d", &l, &r, &k); printf("%d\n", QueryPre(1, 1, n, l, r, k)); } else{ scanf("%d%d%d", &l, &r, &k); printf("%d\n", QueryNxt(1, 1, n, l, r, k)); } } return 0; }