题目描述
输入整数数组 arr ,找出其中最小的 k 个数。例如,输入4、5、1、6、2、7、3、8这8个数字,则最小的4个数字是1、2、3、4。
示例1
输入:arr = [3,2,1], k = 2
输出:[1,2] 或者 [2,1]
示例2
输入:arr = [0,1,2,1], k = 1
输出:[0]
解题思路
- 最简单的方法,先排序后选前 k 个元素。用选择排序,只选 k 个最小的就结束排序了。
- 最大堆方法,当堆的元素个数超过 k 时,将最大的值弹出,最后堆中只保留了 k 个最小值。
- 快速排序法,如果 pivot 刚好在第 k 个位置,那么 0 ~ k 的值是 k 个最小值。如果 pivot > k,那么就排序 0 ~ k - 1。如果 pivot < k,那么就排序 k + 1 到数组末尾。
Java代码实现
- 选择排序
class Solution { public int[] getLeastNumbers(int[] arr, int k) { selectMin(arr, k); return Arrays.copyOf(arr, k); } private void selectMin(int[] arr, int k) { for (int i = 0, j = 0; i < k && j < arr.length; ++i, ++j) { int minIndex = i; for (int m = j + 1; m < arr.length; ++m) { if (arr[m] < arr[minIndex]) { minIndex = m; } } swap(arr, j, minIndex); } } private void swap(int[] arr, int arg0, int arg1) { int temp = arr[arg0]; arr[arg0] = arr[arg1]; arr[arg1] = temp; } }
- 最大堆方法
class Solution { public int[] getLeastNumbers(int[] arr, int k) { if (k <= 0) return new int[0]; // PriorityQueue 默认使用最小堆。将其改为最大堆。 PriorityQueue<Integer> queue = new PriorityQueue<>(k, (o1, o2) -> { return Integer.compare(o2, o1); }); for (int i = 0; i < arr.length; ++i) { // 若堆中元素个数小于 k,或当前元素比堆中最大元素要小,则将当前元素加入堆中 if (queue.size() < k || arr[i] < queue.peek()) queue.offer(arr[i]); // 加入完元素后如果个数超过 k,将最大值删除 if (queue.size() > k) queue.poll(); } int[] res = new int[queue.size()]; for (int i = 0; i < res.length; ++i) { res[i] = queue.poll(); } return res; } }
- 快速排序法
class Solution { public int[] getLeastNumbers(int[] arr, int k) { if (k <= 0) return new int[0]; quickSort(arr, 0, arr.length - 1, k); int[] res = new int[k]; for (int i = 0; i < k; ++i) { res[i] = arr[i]; } return res; } // 筛选排序的位置 public void quickSort(int[] arr, int left, int right, int k) { if (left >= right) return; int pivot = sort(arr, left, right); if (pivot == k) return; else if (pivot < k) quickSort(arr, pivot + 1, right, k); else quickSort(arr, left, pivot - 1, k); } // 快速排序 public int sort(int[] arr,int left, int right) { int temp = arr[left]; while (left < right) { while (left < right && arr[right] > temp) --right; if (left < right) arr[left++] = arr[right]; while (left < right && arr[left] < temp) ++left; if (left < right) arr[right--] = arr[left]; } arr[left] = temp; return left; } }