1. 题目

2. 解答

2.1. 方法一——大顶堆

参考 堆和堆排序 以及 堆的应用,我们将数组的前 K 个位置当作一个大顶堆。

首先建堆,也即对堆中 [0, (K-2)/2] 的节点从上往下进行堆化。第 K/2 个节点若有子节点,其左子节点位置应该为 2 * K/2 + 1 = K+1,而我们堆中的最大位置为 K-1,显然第 K/2 个节点是第一个叶子节点,不用堆化。

建完堆之后,我们顺序访问原数组 [k, n-1] 位置的元素,如果当前元素小于堆顶元素也就是位置为 0 的元素,那么删除堆顶元素并将当前元素插入堆中。

最后,堆中的 K 个元素即为所求。

class Solution {
public:
    vector<int> GetLeastNumbers_Solution(vector<int> input, int k) {
        
        if (k > input.size())
        {
            vector<int> result;
            return result;
        }
        Build_Heap_K(input, k);
        vector<int> result(input.begin(), input.begin()+k);
        return result;
    }
    
    void Build_Heap_K(vector<int> &input, int k)
    {
        // input 的 [0, k-1] 作为一个大小为 K 的大顶堆
        // 然后从上往下进行堆化
        // 也就是对堆中 [0, (k-2)/2] 的节点进行堆化
        for (int i = (k-2)/2; i >= 0; i--)
            Heapify(input, k, i);
        
        // 遍历 input 的 [k, n-1] 的元素
        // 如果某元素小于堆顶值,将其插入堆中
        // 也即将其替换为堆顶元素,堆化之
        for (int i = k; i < input.size(); i++)
        {
            if (input[i] < input[0])
            {
                input[0] = input[i];
                Heapify(input, k, 0);
            }
        }
    }
    
    void Heapify(vector<int> &input, int k, int i)
    {
        while(1)
        {
            int max_pos = i;
            if (2*i+1 < k && input[2*i+1] > input[max_pos])
                max_pos = 2 * i + 1;
            if (2*i+2 < k && input[2*i+2] > input[max_pos])
                max_pos = 2 * i + 2;
            if (max_pos == i)
                   break;
            else
            {
                int temp = input[max_pos];
                input[max_pos] = input[i];
                input[i] = temp;
            }
            i = max_pos;
        }
    }
};
2.2. 方法二——快排分治

可参考 LeetCode 215——数组中的第 K 个最大元素

快排的时候需要分区,分区点左边的元素都小于主元,分区点右边的元素都大于主元。如果分区后主元的位置恰好为 K,那左边正好是最小的 K 个数;如果大于 K,我们需要递归在左边找到第 K 个位置;如果小于 K,我们则需要递归在右边找到第 K 个位置。

class Solution {
public:
    vector<int> GetLeastNumbers_Solution(vector<int> input, int k) {
        int n = input.size();
        if (k > n)
        {
            vector<int> result;
            return result;
        }
        Quick_Sort(input, 0, n-1, k);
        vector<int> result(input.begin(), input.begin()+k);
        return result;
    }
    
    void Quick_Sort(vector<int> &input, int left, int right, int k)
    {
        if (left < right)
        {
            int pivot = input[right];
            int i = left;
            int j = left;
            
            for (; j < right; j++)
            {
                if (input[j] < pivot)
                {
                    int temp = input[i];
                    input[i] = input[j];
                    input[j] = temp;
                    i++;
                }
            }
            input[j] = input[i];
            input[i] = pivot;
            
            if (i == k)    return;
            else if (i > k)    Quick_Sort(input, left, i-1, k);
            else Quick_Sort(input, i+1, right, k);
        }
    }
};

获取更多精彩,请关注「seniusen」!