Splay这个名字在我很早的时候就听说过了,是寒假的时候在lyd的蓝书上学平衡树章节的时候,他在那一章节的最后强力推荐的一种平衡树。"Splay(伸展树)灵活多变,应用广泛,能够很方便地支持各种动态的区间操作,是用于解决复杂问题的一个重要的高级数据结构。"(原话)。 当时就使我就对splay满怀期待,但是碍于各种原因一直没有专门花一块时间去搞懂它就一直拖着(主要是我懒)。这次刚好借算法作业的机会算是完整的去了解它一次了。
Splay简介:
它是通过一种通过splay(伸展)操作来实现平衡的一种平衡二叉查找树。而这个操作就基于一种叫rotate(旋转)的操作来实现的。
Splay最大的特点:
每次对一个结点操作完后,都将其splay到根结点,比如我如果连续的查找某个数10次,那么第一次它会像普通的平衡树一样去找到它。但是找完后它还会附加一个操作就是将该结点通过一系列的旋转使他成为根结点。因为每次查询是从根结点开始的,所以我接下来的9次查询复杂度就是O(1)。(有点夸张
但是尽管不是一些连续相同操作那最坏情况也是单次操作只花费log n的时间。(的确优秀
具体实现代码:
struct Splay { struct Node { int l, r;//左右儿子的结点编号 int size;//当前结点的所存储的值、以该点为根的子树大小 ll val; int cnt;//当前结点的重复次数 }spl[N]; int tot = 0, root = 0;//内存池计数器、根结点编号 ll MAX = 0x3f3f3f3f3f3f3f3f; ll MIN = -0x3f3f3f3f3f3f3f3f; void init() { tot = root = 0; insert(root, MIN); insert(root, MAX); } void newnode(int& now, ll& val) { spl[now = ++tot].val = val; spl[now].size++;//树的大小=1 spl[now].cnt++;//结点重复次数=1 } void update(int now) {//更新size spl[now].size = spl[spl[now].l].size + spl[spl[now].r].size + spl[now].cnt;//树的大小(size) = 左子树的大小 + 右子树的大小+该结点的重复次数 } void zig(int& now) {//右旋,右旋拎左(该结点的左儿子)右(左儿子的右儿子)挂左(该结点的左) int l = spl[now].l; spl[now].l = spl[l].r; spl[l].r = now; now = l; update(spl[now].r), update(now);//更新,注意更新顺序 } void zag(int& now) {//左旋,左旋拎右(该结点的右儿子)左(右儿子的左儿子)挂右(该结点的右) int r = spl[now].r; spl[now].r = spl[r].l; spl[r].l = now; now = r; update(spl[now].l), update(now);//更新,注意更新顺序 } void splaying(int x, int& y) {//我要把x伸展到y那个位置! if (x == y)return;//如果到了终点,return int& l = spl[y].l, & r = spl[y].r;//临时变量 if (x == l)//如果左儿子是终点,那就单旋(右旋) zig(y); else if (x == r)//如果右儿子是终点也是单旋(左旋) zag(y); else {//否则就一定是双旋 if (spl[x].val < spl[y].val) { if (spl[x].val < spl[l].val) splaying(x, spl[l].l), zig(y), zig(y);//zigzig情况 else splaying(x, spl[l].r), zag(l), zig(y);//zagzig情况 } else { if (spl[x].val > spl[r].val) splaying(x, spl[r].r), zag(y), zag(y);//zagzag情况 else splaying(x, spl[r].l), zig(r), zag(y);//zigzag情况 } } } void insert(int& now, ll& val) {//插入新结点后需要将该点伸展到根结点 if (!now)newnode(now, val), splaying(now, root);//如果当前结点为空说明找到了合适的位置可以插入 else if (val < spl[now].val)insert(spl[now].l, val);//如果值小于当前结点就往当前结点的左儿子插 else if (val > spl[now].val)insert(spl[now].r, val);//如果值大于当前结点就往当前结点的右儿子插 else spl[now].size++, spl[now].cnt++, splaying(now, root);//否则该点就是我要插入的值 } void del(int now, ll& val) {//删除操作,先找到要删除的结点然后将其伸展到根结点。 if (spl[now].val == val)delnode(now); else if (val < spl[now].val)del(spl[now].l, val); else del(spl[now].r, val); } void delnode(int now) { splaying(now, root);//将要删除的结点伸展至根节点 if (spl[now].cnt > 1)spl[now].size--, spl[now].cnt--; else if (spl[root].r) {//如果当前结点(即根节点)有后继 int p = spl[root].r; while (spl[p].l)p = spl[p].l;//找到后继 splaying(p, spl[root].r);//将其伸展至根结点的右儿子 spl[spl[root].r].l = spl[root].l;//将根结点的左儿子变为根结点的右儿子的左儿子,从而根节点删除,它的右儿子成为新的根结点 root = spl[root].r;//原来根节点的右儿子成为新的右儿子 update(root); } else root = spl[root].l;//伸展之后没有后继,说明它是最大的了,直接删根节点,它的左儿子成为新的根结点 } ll k_min(int rank) {//查找第rank小的数 ++rank; int now = root; while (now) { int lsize = spl[spl[now].l].size; if (lsize + 1 <= rank && rank <= lsize + spl[now].cnt) { //如果在这个范围内,那就是当前结点 splaying(now, root); break; } else if (lsize >= rank) now = spl[now].l; else { rank -= lsize + spl[now].cnt; now = spl[now].r; } } return spl[now].val; } ll k_max(int rank) {//查找第rank大的数 ++rank; int now = root; while (now) { int rsize = spl[spl[now].r].size; if (rsize + 1 <= rank && rank <= rsize + spl[now].cnt) { //如果在这个范围内,那就是当前结点 splaying(now, root); break; } else if (rsize >= rank) { now = spl[now].r; } else { rank -= rsize + spl[now].cnt; now = spl[now].l; } } return spl[now].val; } int getrank(int val) {//查找值为val的排名 int now = root, rank = 1; while (now) { if (spl[now].val == val) { //找到了要的结点,这个之前的没有 rank += spl[spl[now].l].size; splaying(now, root); break; } else if (val < spl[now].val) now = spl[now].l; else { rank += spl[spl[now].l].size + spl[now].cnt; now = spl[now].r; } } return rank - 1; } ll GetPre(ll val) { int ans = 1; int p = root; while (p) { if (val == spl[p].val) { if (spl[p].l) { p = spl[p].l; while (spl[p].r > 0) p = spl[p].r; ans = p; } break; } if (spl[p].val < val && spl[p].val > spl[ans].val) ans = p; p = val < spl[p].val ? spl[p].l : spl[p].r; } return spl[ans].val; } ll GetNext(ll val) { int ans = 2; int p = root; while (p) { if (val == spl[p].val) { if (spl[p].r > 0) { p = spl[p].r; while (spl[p].l > 0) p = spl[p].l; ans = p; } break; } if (spl[p].val > val && spl[p].val < spl[ans].val) ans = p; p = val < spl[p].val ? spl[p].l : spl[p].r; } return spl[ans].val; } }spl;