前言

在学习KNN算法的时候需要使用到大顶队这个数据结构,这里先来实现一下。

什么是大顶堆

如图所示,大顶堆就是将所有的数据放到一颗完全二叉树中,但是需要满足每个父节点的值都大于子节点的值。

算法实现

1.建立大顶堆

大顶堆的建立,需要制定当前这个大顶堆的大小为0,将来不能超过多大。由于大顶堆是完全二叉树,我们在进行层次遍历的时候元素都是连续的,中间不存在空位,所以我们可以用数组来表示这棵树。那么,我们就再开辟一个数组,来存储大顶堆的元素。

2.元素上浮和下沉

如果大顶堆的最后一个元素比父节点的值还大,那么这个节点就需要和父节点交换位置,如果满足当前节点值比父节点值大,就一直执行这个操作。下沉就是上浮过程反着来,从树的根节点开始。

3.插入元素

在大顶堆中插入一个元素,分为以下两种情况:

  • 堆未满,将元素放在当前最后一个元素的后面,然后执行上浮过程。
  • 堆满了,如果该元素大于堆顶则无法插入,小于堆顶则替换堆顶,再执行下沉过程。

4.推出堆顶元素

大顶堆的交换顶部元素A和最后一个元素B,堆的size减1,再将顶部的B执行下沉过程,最后返回元素A。注意,虽然堆的size减小了1,但实际上并没有元素被删除,数组长度也没有任何变化,被pop的元素只是被放在了数组中size之后的位置。

大顶堆的应用

  • 堆排序降序,我们把顶点元素不停地pop出来,由于每次pop出的元素都是当时最大的,所以把pop的值收集起来就是一个降序数组。
  • 堆排序升序,同方法1,由于顶点元素每次都被pop方法放在了数组的最后一个元素的位置,所以全部pop完毕之后堆中的数组已经是一个升序数组
  • 从N个元素中查找最小的K个元素,把N个元素逐个插入大小为K的大顶堆中,最后大顶堆中的元素就是我们要找的TOP K

代码实现

#coding=utf-8
from time import time
from copy import copy
from random import randint

# 产生一个[low,high)区间的随机数组
def gen_data(low, high, n_rows, n_cols=None):
    if n_cols is None:
        ret = [randint(low, high) for _ in range(n_rows)]
    else:
        ret = [[randint(low, high) for _ in range(n_cols)]
               for _ in range(n_rows)]
    return ret

class MaxHeap(object):
    # 创建MaxHeap类
    def __init__(self, max_size, fn):
        self.max_size = max_size
        self.fn = fn
        self._items = [None] * max_size
        self.size = 0
    # 打印对象中具体的属性值
    def __str__(self):
        item_values = str([self.fn(self.items[i]) for i in range(self.size)])
        return ("Size: %d\nMax size: %d\nItem_values: %s\n" % (self.size, self.max_size, item_values))
    # 获取所有大顶堆的所有值
    @property
    def items(self):
        return self._items[:self.size]
    # 检查大顶堆是否已满
    @property
    def full(self):
        return self.size == self.max_size
    # 获取大顶堆的idx位置的值,如果被删除了,返回-inf
    def value(self, idx):
        item = self._items[idx]
        if item is None:
            ret = -float('inf')
        else:
            ret = self.fn(item)
        return ret
    # 添加元素
    def add(self, item):
        if self.full:
            if self.fn(item) < self.value(0):
                self._items[0] = item
                self.shift_down(0)
        else:
            self._items[self.size] = item
            self.size += 1
            self.shift_up(self.size - 1)
    # 推出顶部元素
    def pop(self):
        assert self.size > 0, "Cannot pop item! The MaxHeap is empty!"
        ret = self.items[0]
        self._items[0] = self._items[self.size - 1]
        self._items[self.size - 1] = None
        self.size -= 1
        self.shit_down(0)
        return ret
    # 元素上浮
    def shift_up(self, idx):
        assert idx < self.size, "The parameter idx must be less than heap's size!"
        parent = (idx - 1) // 2
        while parent >= 0 and self.value(parent) < self.value(idx):
            self._items[parent], self._items[idx] = self._items[idx], self._items[parent]
            idx = parent
            parent = (idx - 1) // 2

    # 元素下沉
    def shift_down(self, idx):
        child = (idx + 1) * 2 - 1
        while child < self.size:
            if child + 1 < self.size and self.value(child + 1) > self.value(child):
                child += 1
            if self.value(idx) < self.value(child):
                self._items[idx], self._items[child] = self._items[child], self._items[idx]
                idx = child
                child = (idx + 1) * 2 - 1
            else:
                break
    # 检查有效性
    def is_valid(self):
        ret = []
        for i in range(1, self.size):
            parent = (i - 1) // 2
            ret.append(self.value(parent) >= self.value(i))
        # all()函数用于判定可迭代参数iterable中的所有元素是否都为TRUE,如果是返回True,否则返回False
        return all(ret)

# 暴力查找nums中最小的k各元素
def exhausted_search(nums, k):
    rets = []
    idxs = []
    key = None
    for _ in range(k):
        val = float("inf")
        for i, num in enumerate(nums):
            if num < val and i not in idxs:
                key = i
                val = num
        idxs.append(key)
        rets.append(val)
    return rets

# 主函数分为下面几个部分
# 1. 随机生成数据集,即测试用例
# 2. 建立大顶堆
# 3. 调用exhausted_search查找
# 4. 使用大顶堆
def main():
    # Test
    print("Testing MaxHeap...")
    test_times = 100
    run_time_1 = run_time_2 = 0
    for _ in range(test_times):
        # Generate dataset randomly
        low = 0
        high = 1000
        n_rows = 10000
        k = 100
        nums = gen_data(low, high, n_rows)

        # Build Max Heap
        heap = MaxHeap(k, lambda x: x)
        start = time()
        for num in nums:
            heap.add(num)
        ret1 = copy(heap.items)
        run_time_1 += time() - start

        # Exhausted search
        start = time()
        ret2 = exhausted_search(nums, k)
        run_time_2 += time() - start

        # Compare result
        ret1.sort()
        assert ret1 == ret2, "target:%s\nk:%d\nrestult1:%s\nrestult2:%s\n" % (
            str(nums), k, str(ret1), str(ret2))
    print("%d tests passed!" % test_times)
    print("Max Heap Search %.2f s" % run_time_1)
    print("Exhausted search %.2f s" % run_time_2)

main()

我们来看一下搜索前100小的数的时间消耗(相对于暴力查找):
Testing MaxHeap…
100 tests passed!
Max Heap Search 1.80 s
Exhausted search 9.71 s

结论

可以看到使用MaxHeap算法可以帮助我们大量节省搜索时间,这在后面的KNN算法中有应用。