/*
平衡树标准功能：
(1) 插入数 x
(2) 删除数 x（若有多个相同的数，只删除一个）
(3) 查询数 x 的排名（排名定义为比 x 小的数的个数 + 1）
(4) 查询排名为 x 的数
(5) 求 x 的前驱（小于 x 的最大数）
(6) 求 x 的后继（大于 x 的最小数）
题目：zngg 课程第四讲第二题——普通平衡树
*/
#include <cstdio>
#include <cstdlib>

// Treap ("tree + heap") over int values: an in-order BST by `val` that is
// also a max-heap by the random priority `pri`. Equal values share one node
// and are counted in `num`. Nodes live in the static pool `tr`; index 0 is a
// null sentinel whose fields stay zero (notably tr[0].sz == 0), so child
// sizes can be read without null checks.

const int maxn = 100000 + 10;  // pool capacity; each Insert allocates at most one node
int n, cnt, rt;                // n: operation count, cnt: nodes allocated so far, rt: root index (0 = empty tree)

struct node{
    int lc, rc;    // left / right child indices (0 = no child)
    int val, pri;  // stored value and random heap priority
    int num, sz;   // copies of `val` in this node / total copies in this subtree
}tr[maxn];

// Allocates a node holding one copy of v and returns its index.
// Priorities come from an unseeded std::rand(), so runs are reproducible;
// equal priorities only affect balance, not correctness.
int NewNode(int v){
    tr[++cnt].val = v;
    tr[cnt].pri = std::rand();
    tr[cnt].num = tr[cnt].sz = 1;
    tr[cnt].lc = tr[cnt].rc = 0;
    return cnt;
}

// Recomputes tr[p].sz from its children's sizes and its own multiplicity.
void Update(int p){
    tr[p].sz = tr[tr[p].lc].sz + tr[tr[p].rc].sz + tr[p].num;
}

// Right rotation: lifts p's left child into p's position (p is rewritten in place).
void Zig(int &p){
    int q = tr[p].lc;
    tr[p].lc = tr[q].rc;
    tr[q].rc = p;
    tr[q].sz = tr[p].sz;  // q now roots exactly the values p used to root
    Update(p);
    p = q;
}

// Left rotation: lifts p's right child into p's position (p is rewritten in place).
void Zag(int &p){
    int q = tr[p].rc;
    tr[p].rc = tr[q].lc;
    tr[q].lc = p;
    tr[q].sz = tr[p].sz;  // q now roots exactly the values p used to root
    Update(p);
    p = q;
}

// Inserts one copy of v into the subtree rooted at p, rotating the affected
// child up whenever its priority exceeds its parent's.
void Insert(int &p, int v){
    if (!p){
        p = NewNode(v);
        return;
    }
    tr[p].sz++;  // insertion always succeeds, so every node on the path grows by one
    if (v == tr[p].val){
        tr[p].num++;
    } else if (v < tr[p].val){
        Insert(tr[p].lc, v);
        if (tr[p].pri < tr[tr[p].lc].pri) Zig(p);
    } else {
        Insert(tr[p].rc, v);
        if (tr[p].pri < tr[tr[p].rc].pri) Zag(p);
    }
}

// Removes one copy of v from the subtree rooted at p; does nothing if v is
// absent. Sizes are recomputed on the way back up (instead of being
// decremented on the way down) so deleting a missing value leaves them intact.
void Delete(int &p, int v){
    if (!p) return;
    if (v == tr[p].val){
        if (tr[p].num > 1){
            tr[p].num--;
            Update(p);
            return;
        }
        if (!tr[p].lc || !tr[p].rc){
            p = tr[p].lc + tr[p].rc;  // at most one child: splice it in (0 if leaf)
            return;
        }
        // Two children: rotate the higher-priority child up to keep heap
        // order, then keep deleting v, which has moved one level down.
        if (tr[tr[p].lc].pri > tr[tr[p].rc].pri){
            Zig(p);
            Delete(tr[p].rc, v);
        } else {
            Zag(p);
            Delete(tr[p].lc, v);
        }
        Update(p);
        return;
    }
    if (v < tr[p].val) Delete(tr[p].lc, v);
    else Delete(tr[p].rc, v);
    Update(p);
}

// Returns the largest stored value strictly less than v, or 0 if none exists.
int GetPre(int v){
    int p = rt;
    int res = 0;
    while (p){
        if (tr[p].val < v){
            res = tr[p].val;
            p = tr[p].rc;
        } else {
            p = tr[p].lc;
        }
    }
    return res;
}

// Returns the smallest stored value strictly greater than v, or 0 if none exists.
int GetNxt(int v){
    int p = rt;
    int res = 0;
    while (p){
        if (tr[p].val > v){
            res = tr[p].val;
            p = tr[p].lc;
        } else {
            p = tr[p].rc;
        }
    }
    return res;
}

// Returns (number of stored values < v) + 1 within the subtree rooted at p.
// This also holds when v is not stored (the null base case contributes the +1).
int GetRankByVal(int p, int v){
    if (!p) return 1;
    if (tr[p].val == v) return tr[tr[p].lc].sz + 1;
    if (v < tr[p].val) return GetRankByVal(tr[p].lc, v);
    return GetRankByVal(tr[p].rc, v) + tr[tr[p].lc].sz + tr[p].num;
}

// Returns the rk-th smallest value (1-based, counting duplicates) in the
// subtree rooted at p, or 0 if rk is out of range.
int GetValByRank(int p, int rk){
    if (!p) return 0;
    if (tr[tr[p].lc].sz >= rk) return GetValByRank(tr[p].lc, rk);
    if (tr[tr[p].lc].sz + tr[p].num >= rk) return tr[p].val;
    return GetValByRank(tr[p].rc, rk - tr[tr[p].lc].sz - tr[p].num);
}

// Reads n, then n lines "op x" and applies operation op (1..6) as listed in
// the file header; queries print one result per line.
int main(){
    // freopen("input.txt", "r", stdin);  // uncomment for local testing
    if (std::scanf("%d", &n) != 1) return 0;
    while (n--){
        int op, x;
        if (std::scanf("%d%d", &op, &x) != 2) break;  // truncated input
        switch (op){
            case 1: Insert(rt, x); break;
            case 2: Delete(rt, x); break;
            case 3: std::printf("%d\n", GetRankByVal(rt, x)); break;
            case 4: std::printf("%d\n", GetValByRank(rt, x)); break;
            case 5: std::printf("%d\n", GetPre(x)); break;
            case 6: std::printf("%d\n", GetNxt(x)); break;
            default: break;  // unknown opcode: ignore
        }
    }
    return 0;
}