TOP-K问题

一堆数中,找出前k大,或者前k小问题。一般来说,我们先sort,然后取值即可。目前解决TOP-K问题最有效的算法即是BFPRT算法,又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为O(n)。不过想想有点不可思议,那k=n时岂不是也时间复杂度只有O(n)?
核心是:修改快速排序中pivot的选取

不得不说,什么CSDN里各种博客都是相互抄,看了很多都是同一篇帖子,说明白的没几个,以下是我总结的。

算法步骤:

  1. 将n个元素每5个一组,分成n/5(上界)组,我到现在还在想为什么是5个一组,三个1组行不行,10个一组行不行...
  2. 每组排序,找出中位数,5个一组的话,下标就是[2],最后一组不用给中位数
  3. 递归的调用,直到最后只有一个中位数,设为轴心,pivot,偶数个中位数的情况下设定为选取中间小的一个。在递归的过程中,最后肯定是构成的中位数<=5,最终返回的就是这个里面的中位数。
  4. 用pivot来分割数组,left=[<=pivot],[pivot],right=[>pivot]。
  5. 若len(left)+1==k,说明pivot就是要找的第k小的数,返回pivot即可;
    若len(left)+1<k,在right中递归整个过程,topK(arr,k-len(left)+1)
    若len(left)+1>k,在left中递归整个过程,topK(arr,k);
  6. 再严谨点的话可以加上异常处理

我们来试着按照说明做一下:

class Solution:
    def topK(self,arr,k): # 定义主体函数
        pivot = self.selectMid(arr) # 选取pivot
        arr.remove(pivot) # 剔除pivot
        left = [x for x in arr if x<=pivot] 
        right = [x for x in arr if x>pivot]
        indexOfPivot=len(left)+1
        if indexOfPivot==k:
            return pivot
        elif indexOfPivot<k:
            return self.topK(right,k-indexOfPivot)
        else:
            return self.topK(left,k)

    def selectMid(self,arr): # 定义函数,最终给出唯一的pivot
        # 退出条件
        if len(arr)<=5:
            arr = self.insertSort(arr)
            return arr[int(len(arr)/2) if len(arr)&1==1 else int(len(arr)/2)-1]
        # 非退出条件,正常循环体  
        n=int(len(arr)/5)
        tmp = []
        for num in range(n):
            i=num*5  # 这一步很容易忽视
            tmp.append(self.insertSort(arr[i:i+5])[2])
        return self.selectMid(tmp)

    def insertSort(self, arr): # 返回插序排列
        for i in range(1,len(arr)):
            key=arr[i]
            j=i-1
            while j>=0 and key<arr[j]:
                arr[j+1]=arr[j]
                j-=1
            arr[j+1]=key
        return arr 

如果用sort()函数,代码会更短,但是用了sort还有什么意义呢,不如一开始就sort()掉然后按顺序取得了,所以要忍住用现成sort()的手,尽量理解BFPRT算法的精髓吧。
下面也给出用sort()函数简化版:

class Solution:
    def topK(self,arr,k):
        pivot = self.selectMid(arr)
        arr.remove(pivot)

        left = [x for x in arr if x<=pivot]
        right = [x for x in arr if x>pivot]
        indexOfPivot=len(left)+1

        if indexOfPivot==k:
            return pivot
        elif indexOfPivot<k:
            return self.topK(right,k-indexOfPivot)
        else:
            return self.topK(left,k)


    def selectMid(self,arr): # 取出x
        # 退出条件
        if len(arr)<=5:
            arr.sort()
            return arr[int(len(arr)/2) if len(arr)&1==1 else int(len(arr)/2)-1]

        n=int(len(arr)/5)
        tmp = []

        for num in range(n):
            i=num*5
            tmp.append(sorted(arr[i:i+5])[2])
        return self.selectMid(tmp)