题目描述

输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。

示例1

输入:arr = [3,2,1], k = 2
输出:[1,2] 或者 [2,1]

示例2

输入:arr = [0,1,2,1], k = 1
输出:[0]

解题思路

  1. 最简单的方法,先排序后选前 k 个元素。用选择排序,只选 k 个最小的就结束排序了。
  2. 最大堆方法,当堆的元素个数超过 k 时,将最大的值弹出,最后堆中只保留了 k 个最小值。
  3. 快速排序法,如果 pivot 刚好在第 k 个位置,那么 0 ~ k 的值是 k 个最小值。如果 pivot > k,那么就排序 0 ~ k - 1。如果 pivot < k,那么就排序 k + 1 到数组末尾。

Java代码实现

  1. 选择排序
class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        selectMin(arr, k);
        return Arrays.copyOf(arr, k);
    }

    private void selectMin(int[] arr, int k) {
        for (int i = 0, j = 0; i < k && j < arr.length; ++i, ++j) {
            int minIndex = i;
            for (int m = j + 1; m < arr.length; ++m) {
                if (arr[m] < arr[minIndex]) {
                    minIndex = m;
                }
            }
            swap(arr, j, minIndex);
        }
    }

    private void swap(int[] arr, int arg0, int arg1) {
        int temp = arr[arg0];
        arr[arg0] = arr[arg1];
        arr[arg1] = temp;
    }
}
  1. 最大堆方法
class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k <= 0) return new int[0];
        // PriorityQueue 默认使用最小堆。将其改为最大堆。
        PriorityQueue<Integer> queue = new PriorityQueue<>(k, (o1, o2) -> {
            return Integer.compare(o2, o1);
        });

        for (int i = 0; i < arr.length; ++i) {
            // 若堆中元素个数小于 k,或当前元素比堆中最大元素要小,则将当前元素加入堆中
            if (queue.size() < k || arr[i] < queue.peek()) queue.offer(arr[i]);
            // 加入完元素后如果个数超过 k,将最大值删除
            if (queue.size() > k) queue.poll();
        }

        int[] res = new int[queue.size()];
        for (int i = 0; i < res.length; ++i) {
            res[i] = queue.poll();
        }
        return res;
    }
}
  1. 快速排序法
class Solution {
    public int[] getLeastNumbers(int[] arr, int k) {
        if (k <= 0) return new int[0];
        quickSort(arr, 0, arr.length - 1, k);
        int[] res = new int[k];
        for (int i = 0; i < k; ++i) {
            res[i] = arr[i];
        }
        return res;
    }

    // 筛选排序的位置
    public void quickSort(int[] arr, int left, int right, int k) {
        if (left >= right) return;
        int pivot = sort(arr, left, right);
        if (pivot == k) return;
        else if (pivot < k) quickSort(arr, pivot + 1, right, k);
        else quickSort(arr, left, pivot - 1, k);
    }

    // 快速排序
    public int sort(int[] arr,int left, int right) {
        int temp = arr[left];
        while (left < right) {
            while (left < right && arr[right] > temp) --right;
            if (left < right) arr[left++] = arr[right];
            while (left < right && arr[left] < temp) ++left;
            if (left < right) arr[right--] = arr[left];
        }
        arr[left] = temp;
        return left;
    }
}