// 平衡树标准功能
// (1)插入
// (2)删除
// (3)查询数 x 的排名
// (4)查询排名为 x 的数
// (5)求 x 的前驱
// (6)求 x 的后继
// 题目为:zngg课程的第四讲第二题:普通平衡树
#include <bits/stdc++.h>
using namespace std;
// Capacity of the node pool: one slot per possible insertion plus slack.
const int maxn = 100000 + 10;

// Globals are zero-initialized.
int n;    // number of operations still to be processed (read in main)
int cnt;  // index of the most recently allocated node; node ids start at 1
int rt;   // index of the treap root, 0 while the tree is empty

// One treap node, stored by index inside the static pool `tr`.
// Index 0 is used as the "no child" marker throughout the file.
struct node{
    int lc;   // index of the left child, 0 if none
    int rc;   // index of the right child, 0 if none
    int val;  // key stored in this node
    int pri;  // heap priority assigned with rand() at creation
    int num;  // how many copies of val are stored here
    int sz;   // number of stored values (copies included) in this subtree
}tr[maxn];
int NewNode(int v){
tr[++cnt].val = v;
tr[cnt].pri = rand();
tr[cnt].num = tr[cnt].sz = 1;
tr[cnt].lc = tr[cnt].rc = 0;
return cnt;
}
// Recomputes the subtree size of node p from its two children and its own
// copy count. A missing child is index 0, so tr[0].sz is what gets added.
void Update(int p){
    const int leftSize = tr[tr[p].lc].sz;
    const int rightSize = tr[tr[p].rc].sz;
    tr[p].sz = leftSize + rightSize + tr[p].num;
}
// Right rotation around p: the left child of p is lifted to the top of the
// subtree and p becomes its right child. p is rewritten to the new top.
void Zig(int &p){
    const int oldTop = p;
    const int newTop = tr[oldTop].lc;
    // The right subtree of the lifted node moves under the old top.
    tr[oldTop].lc = tr[newTop].rc;
    tr[newTop].rc = oldTop;
    // The subtree holds the same values as before, so the new top inherits
    // the total; only the old top has to be recomputed.
    tr[newTop].sz = tr[oldTop].sz;
    Update(oldTop);
    p = newTop;
}
// Left rotation around p: the right child of p is lifted to the top of the
// subtree and p becomes its left child. p is rewritten to the new top.
void Zag(int &p){
    const int oldTop = p;
    const int newTop = tr[oldTop].rc;
    // The left subtree of the lifted node moves under the old top.
    tr[oldTop].rc = tr[newTop].lc;
    tr[newTop].lc = oldTop;
    // The subtree holds the same values as before, so the new top inherits
    // the total; only the old top has to be recomputed.
    tr[newTop].sz = tr[oldTop].sz;
    Update(oldTop);
    p = newTop;
}
// Inserts one copy of v into the subtree whose root index is p.
// p is passed by reference because it is overwritten when a node is created
// here or when a rotation changes the root of this subtree.
void Insert(int &p, int v){
    // Empty spot reached: the new leaf goes here.
    if (p == 0){
        p = NewNode(v);
        return;
    }

    // An insertion always adds one value somewhere below p.
    ++tr[p].sz;

    // Duplicate key: only the copy counter grows.
    if (tr[p].val == v){
        ++tr[p].num;
        return;
    }

    if (v < tr[p].val){
        Insert(tr[p].lc, v);
        // A child with a larger priority than its parent is rotated up.
        if (tr[tr[p].lc].pri > tr[p].pri)
            Zig(p);
        return;
    }

    // v > tr[p].val
    Insert(tr[p].rc, v);
    if (tr[tr[p].rc].pri > tr[p].pri)
        Zag(p);
}
void Delete(int &p, int v){
if (!p) return;
tr[p].sz--;
if (v == tr[p].val){
if (tr[p].num > 1){
tr[p].num--;
return;
}
if (!tr[p].lc || !tr[p].rc)
p = tr[p].lc + tr[p].rc;
else if (tr[tr[p].lc].pri > tr[tr[p].rc].pri){
Zig(p);
Delete(tr[p].rc, v);
}
else{
Zag(p);
Delete(tr[p].lc, v);
}
return;
}
if (v < tr[p].val)
Delete(tr[p].lc, v);
else
Delete(tr[p].rc, v);
}
// Returns the largest stored key that is strictly smaller than v,
// searching from the global root rt. Returns 0 when no such key exists.
int GetPre(int v){
    int best = 0;
    for (int cur = rt; cur != 0; ){
        if (tr[cur].val >= v){
            // Not a candidate; anything smaller lies to the left.
            cur = tr[cur].lc;
            continue;
        }
        // Candidate found; a closer one can only be to the right.
        best = tr[cur].val;
        cur = tr[cur].rc;
    }
    return best;
}
// Returns the smallest stored key that is strictly greater than v,
// searching from the global root rt. Returns 0 when no such key exists.
int GetNxt(int v){
    int best = 0;
    for (int cur = rt; cur != 0; ){
        if (tr[cur].val <= v){
            // Not a candidate; anything greater lies to the right.
            cur = tr[cur].rc;
            continue;
        }
        // Candidate found; a closer one can only be to the left.
        best = tr[cur].val;
        cur = tr[cur].lc;
    }
    return best;
}
// Returns the rank of v within the subtree rooted at p, defined as
// (number of stored values strictly smaller than v) + 1.
// The definition does not require v to be stored: a missing v gets the rank
// it would have if it were inserted, and an empty subtree yields 1.
int GetRankByVal(int p, int v){
    int smaller = 0;  // stored values known to be < v so far
    while (p){
        if (v == tr[p].val){
            // Everything in the left subtree is smaller; nothing else is.
            smaller += tr[tr[p].lc].sz;
            break;
        }
        if (v < tr[p].val){
            p = tr[p].lc;
        }
        else{
            // This node and its whole left subtree are smaller than v.
            smaller += tr[tr[p].lc].sz + tr[p].num;
            p = tr[p].rc;
        }
    }
    return smaller + 1;
}
// Returns the key at 1-based position rk in sorted order (copies counted
// individually) within the subtree rooted at p.
// Returns 0 when the search runs off the tree, i.e. rk is out of range.
int GetValByRank(int p, int rk){
    while (p != 0){
        const int leftSize = tr[tr[p].lc].sz;
        if (rk <= leftSize){
            // The wanted position lies inside the left subtree.
            p = tr[p].lc;
            continue;
        }
        const int throughHere = leftSize + tr[p].num;
        if (rk <= throughHere)
            return tr[p].val;
        // Skip the left subtree and this node's copies, then go right.
        rk -= throughHere;
        p = tr[p].rc;
    }
    return 0;
}
// Reads the number of operations, then one "op x" pair per operation:
//   1 insert x            2 delete one copy of x
//   3 print rank of x     4 print the value at rank x
//   5 print predecessor   any other op: print successor of x
// Input that ends early or cannot be parsed stops processing instead of
// acting on uninitialized values.
int main(){
    // freopen("input.txt", "r", stdin);
    if (scanf("%d", &n) != 1) return 0;

    // "> 0" keeps a negative count from turning into a runaway loop.
    while (n-- > 0){
        int op, x;
        if (scanf("%d%d", &op, &x) != 2) break;

        switch (op){
            case 1: Insert(rt, x); break;
            case 2: Delete(rt, x); break;
            case 3: printf("%d\n", GetRankByVal(rt, x)); break;
            case 4: printf("%d\n", GetValByRank(rt, x)); break;
            case 5: printf("%d\n", GetPre(x)); break;
            default: printf("%d\n", GetNxt(x)); break;
        }
    }
    return 0;
}



// 京公网安备 11010502036488号