#二叉搜索树
class TreeNode:
    def __init__(self,key,val,size = 1,color = None,left=None,right=None):
        #key符号表,val值 color:颜色,红或黑,此处没用,size 以此结点做根的树大小
        self.key=key
        self.val=val
        self.left = left
        self.right = right
        self.color = color
        self.size = size
    def __len__(self):
        return self.size
class BST:
    def __init__(self):
        self.root = None
        self.size = 0
    def __len__(self):
        #len()方法
        return self.size
    def __str__(self):
        #打印BST信息
        return 'BST size :(%d)' % self.size
    __repr__ = __str__
    def __contains__(self, key):
        #in方法
        res_val = self._find(self.root,key)
        if res_val==None:
            return False
        return True 
    #构造方法
    def __setitem__(self, key, val):
        return self.insert(key,val)
    #随机访问方法
    def __getitem__(self, key):
        if isinstance(key, int):
            return self._find_val(self.root,key)
        #和数组不同,-1代表key是-1,可以为负,而不是倒数的最后一个
        #可以切片
        if isinstance(key, slice):
            res = []
            d = 1
            if key.start>key.stop:
                d = -1
            for index in range(key.start,key.stop,d):
                res.append(self._find_val(self.root,index))
            return res
    #数据插入
    def insert(self,key,val):
        if not self.root:
            self.root = TreeNode(key,val)
            self.size = 1
            return True
        else:
            #cur = 0 or 1
            #利用python的浅拷贝
            cur = self._insert_help(key,val,self.root)
            self.size += cur
            if cur==1:
                return True
            return False
    def _insert_help(self,key,val,root):
        #二分,找不到创建新的,找得到覆盖
        if not root:
            root =  TreeNode(key,val)
            return 1
        if key < root.key:
            if root.left:
                cur = self._insert_help(key,val,root.left)
                root.size += cur
                return cur
            root.left = TreeNode(key,val)
            root.size += 1
            return 1
        if key > root.key:
            if root.right:
                cur = self._insert_help(key,val,root.right)
                root.size += cur
                return cur
            root.right = TreeNode(key,val)
            root.size += 1
            return 1
        if key == root.key:
            root.val = val
        return 0
    def pre_print(self,bottom = False):
        res = []
        self._pre_print(self.root,res)
        print('pre:',res)
    def _pre_print(self,root,res):
        if not root:
            return
        res.append(root.val)
        self._pre_print(root.left,res)
        self._pre_print(root.right,res)
        return
    def mid_print(self):
        res = []
        self._mid_print(self.root,res)
        print('mid:',res)
    def _mid_print(self,root,res):
        if not root:
            return
        self._mid_print(root.left,res)
        res.append(root.val)
        self._mid_print(root.right,res)
        return
    def las_print(self):
        res = []
        self._las_print(self.root,res)
        print('las:',res)
    def _las_print(self,root,res):
        if not root:
            return
        self._las_print(root.left,res)
        self._las_print(root.right,res)
        res.append(root.val)
        return
    #find
    def find_val(self,key):
        return self._find_val(self.root,key)
    def _find_val(self,root,key):
        if not root:
            return None
        if key<root.key:
             return self._find_val(root.left,key)
        if key>root.key:
             return self._find_val(root.right,key)
        if key==root.key:
             return root.val
    def find(self,key):
        return self._find(self.root,key)
    def _find(self,root,key):
        if not root:
            return None
        if key<root.key:
             return self._find(root.left,key)
        if key>root.key:
             return self._find(root.right,key)
        if key==root.key:
             return root
    #查找树的结点数目
    def _size(self,root):
        if not root:
            return 0
        return len(root)
    def rank(self,key):
        if not self.find(key):
            return 0
        return self._rank(self.root,key)
    def _rank(self,root,key):
        #key的rank
        if not root:
            return 0
        if key<root.key:
            return self._rank(root.left,key)
        if key==root.key:
            return 1 + self._size(root.left)
        if key>root.key:
            return 1 + self._size(root.left) + self._rank(root.right,key)
    #最大最小
    def max(self,root = None):
        return self._max(root).val
    def min(self,root = None):
        return self._min(root).val
    def _max(self,root=None):
        #最大
        if root is None:
            root = self.root
        while root.right:
            root = root.right
        return root
    def _min(self,root=None):
        #最小
        if root is None:
            root = self.root
        while root.left:
            root = root.left
        return root   
    #delete
    #最小结点删除,辅助函数
    def _delete_min(self,root):
        #root不能为空
        #root左子树为空,此时最小就是根结点,返回删除根结点后的右子树,即可将根结点删除
        if not root.left:
            #删除,此时size没有发生变换
            #self.size-=1
            return root.right
        #左子树不为空,往左子树里面删,返回左子树里面删除最小值后的左子树
        root.left = _delete_min(root.left)
        return root
    def delete(self,key):
        #空树
        if not self.root:
            return False
        res = [False]
        self.root = self._delete(self.root,key,res)
        return res[0]
    def _delete(self,root,key,res=[False]):
        if not root:
            return root
        if key<root.key:
            root.left = self._delete(root.left,key,res)
            root.size = self._size(root.left)+self._size(root.right)+1
            return root
        if key>root.key:
            root.right = self._delete(root.right,key,res)
            root.size = self._size(root.left)+self._size(root.right)+1
            return root
        #提前减小size,key==root.key,无论哪种情况root都会被删掉
        self.size-=1
        res[0] = True
        if not root.left:
            return root.right
        if not root.right:
            return root.left
        #双方结点都不空
        #简单的写法是直接把删除结点的左结点放到删除结点右子树最左叶子结点的左边,
        #用删除结点的右子树覆盖原先的删除结点,问题是树的高度增高了很多影响性能
        #即时删除
        #将指向即将被删除的节点的链接保存为t
        #将x指向后继节点min(t.right) #这一步就是原先结点的位置用所有大于他的右边最小的代替
        #右子树删除最小结点,因为最小结点在根点所有很容易删除掉
        #将x的左结点设为删除结点t的左结点
        t = root
        root = _min(t.right)
        root.right = self._delete_min(root.right)#此函数不会将self.size-1
        root.left = t.left
        #原先的root返回时被删掉
        return root
    def get_range(self,key_st,key_en):
        #左开右闭
        if key_en<=key_st or not self.root:
            return []
        res = []
        self._get_range(self.root,key_st,key_en,res)
        return res
    def _get_range(self,root,st,en,res):
        if not root:
            return
        #中序遍历
        if root.key>=st:
            self._get_range(root.left,st,en,res)
        if st<=root.key and root.key<en:
            res.append(root.val)
        if root.key<en:
            self._get_range(root.right,st,en,res)
        return


a = BST()
print(a)
#初始化
a[2]='b'
a[1]='a'
a.insert(3,'c'),a.insert(-1,'Y'),a.insert(0,'Z'),a.insert(5,'e'),a.insert(4,'d')
print('打印keys,vals:')
print([i for i in range(-1,5+1)],['Y','Z','a','b','c','d','e'])
print('len功能:%d。in功能\'2 in a\':%s。随机访问功能BST[2]:%s,BST[1:3]:%s'%(len(a),2 in a,a[2],a[1:3]))
print('前中后序遍历')
a.pre_print(),a.mid_print(),a.las_print()#中序遍历有序
print('find功能,0:%s,1:%s'%(a.find_val(0),a.find_val(1)))
print('最小值and最大值:',a.min(),a.max())
print('区域搜索功能[1-4):',a.get_range(1,4))
print('size:',a.root.size,a.root.left.size,a.root.right.size)
for i in [-1,0,1,2,3,4,5]:
    print('rank %s is %s'%(i,a.rank(i)))
print('delete key=1',a.delete(1),a[1])
for i in [-1,0,1,2,3,4,5]:
    print('rank %d is %d'%(i,a.rank(i)))


#删除操作
print('删除-1',a.delete(-1))
a.pre_print(),a.mid_print()
print('删除-1',a.delete(-1))
a.pre_print(),a.mid_print()
print('删除3',a.delete(3))
a.pre_print(),a.mid_print()
print('删除3',a.delete(3))
a.pre_print(),a.mid_print()