TOP-K问题
一堆数中,找出前k大,或者前k小问题。一般来说,我们先sort,然后取值即可。目前解决TOP-K问题最有效的算法即是BFPRT算法,又称为中位数的中位数算法,该算法由Blum、Floyd、Pratt、Rivest、Tarjan提出,最坏时间复杂度为O(n)。不过想想有点不可思议,那k=n时岂不是也时间复杂度只有O(n)?
核心是:修改快速排序中pivot的选取。
不得不说,什么CSDN里各种博客都是相互抄,看了很多都是同一篇帖子,说明白的没几个,以下是我总结的。
算法步骤:
- 将n个元素每5个一组,分成n/5(上界)组,我到现在还在想为什么是5个一组,三个1组行不行,10个一组行不行...
- 每组排序,找出中位数,5个一组的话,下标就是[2],最后一组不用给中位数
- 递归的调用,直到最后只有一个中位数,设为轴心,pivot,偶数个中位数的情况下设定为选取中间小的一个。在递归的过程中,最后肯定是构成的中位数<=5,最终返回的就是这个里面的中位数。
- 用pivot来分割数组,left=[<=pivot],[pivot],right=[>pivot]。
- 若len(left)+1==k,说明pivot就是要找的第k小的数,返回pivot即可;
若len(left)+1<k,在right中递归整个过程,topK(arr,k-len(left)+1)
若len(left)+1>k,在left中递归整个过程,topK(arr,k);- 再严谨点的话可以加上异常处理
我们来试着按照说明做一下:
class Solution: def topK(self,arr,k): # 定义主体函数 pivot = self.selectMid(arr) # 选取pivot arr.remove(pivot) # 剔除pivot left = [x for x in arr if x<=pivot] right = [x for x in arr if x>pivot] indexOfPivot=len(left)+1 if indexOfPivot==k: return pivot elif indexOfPivot<k: return self.topK(right,k-indexOfPivot) else: return self.topK(left,k) def selectMid(self,arr): # 定义函数,最终给出唯一的pivot # 退出条件 if len(arr)<=5: arr = self.insertSort(arr) return arr[int(len(arr)/2) if len(arr)&1==1 else int(len(arr)/2)-1] # 非退出条件,正常循环体 n=int(len(arr)/5) tmp = [] for num in range(n): i=num*5 # 这一步很容易忽视 tmp.append(self.insertSort(arr[i:i+5])[2]) return self.selectMid(tmp) def insertSort(self, arr): # 返回插序排列 for i in range(1,len(arr)): key=arr[i] j=i-1 while j>=0 and key<arr[j]: arr[j+1]=arr[j] j-=1 arr[j+1]=key return arr
如果用sort()函数,代码会更短,但是用了sort还有什么意义呢,不如一开始就sort()掉然后按顺序取得了,所以要忍住用现成sort()的手,尽量理解BFPRT算法的精髓吧。
下面也给出用sort()函数简化版:
class Solution: def topK(self,arr,k): pivot = self.selectMid(arr) arr.remove(pivot) left = [x for x in arr if x<=pivot] right = [x for x in arr if x>pivot] indexOfPivot=len(left)+1 if indexOfPivot==k: return pivot elif indexOfPivot<k: return self.topK(right,k-indexOfPivot) else: return self.topK(left,k) def selectMid(self,arr): # 取出x # 退出条件 if len(arr)<=5: arr.sort() return arr[int(len(arr)/2) if len(arr)&1==1 else int(len(arr)/2)-1] n=int(len(arr)/5) tmp = [] for num in range(n): i=num*5 tmp.append(sorted(arr[i:i+5])[2]) return self.selectMid(tmp)