说明

这是很常见的 TopK 考题,有 2 种方法:

  • 方法一:快速排序的 parititon,这种方法需要一下子把数据全部读入内存;
  • 方法二:优先队列,可以应对数据量很大的情况。

方法一:减治思想(逐渐缩小搜索区间)

必需要会的知识点:快速排序的 parititon

parititon:遍历一次区间,将数组根据 pivot 划分成两个部分,左边的部分 <= pivot,右边的部分 > pivot

参考代码 1

import java.util.ArrayList;

public class Solution {

    public ArrayList<Integer> GetLeastNumbers_Solution(int[] input, int k) {
        if (k == 0) {
            return new ArrayList<>();
        }

        int len = input.length;
        int left = 0;
        int right = len - 1;

        // 找最小的 K 个数,即找下标 k - 1
        while (left <= right) {
            int index = partition(input, left, right);
            if (index == k - 1) {
                ArrayList<Integer> res = new ArrayList<>();
                for (int i = 0; i < k; i++) {
                    res.add(input[i]);
                }
                return res;
            } else if (index > k - 1) {
                right = index - 1;
            } else {
                left = index + 1;
            }
        }
        return null;
    }

    private int partition(int[] input, int left, int right) {
        // input[left + 1..le] <= pivot
        // input(le..i) > pivot

        int pivot = input[left];
        int le = left;
        for (int i = left + 1; i <= right; i++) {
            if (input[i] <= pivot) {
                le++;
                swap(input, le, i);
            }
        }
        swap(input, left, le);
        return le;
    }

    private void swap(int[] input, int index1, int index2) {
        int temp = input[index1];
        input[index1] = input[index2];
        input[index2] = temp;
    }
}

复杂度分析

  • 时间复杂度:,这里 是输入数组的长度;
  • 空间复杂度:,只使用到常数个变量。

说明:有学习过快速排序的朋友们一定知道,快速排序的 pivot 必需随机选择,否则在面对有序数组(顺序或者逆序)的时候,划分数组的区间是偏斜的,此时快速排序的时间复杂度为 ,为了避免这种情况,需要随机选择 pivot

「参考代码 2」在「参考代码 1」的基础上,在 partition 方法的内部,只加了两行代码。

int randomIndex = left + random.nextInt(right - left + 1);
swap(input, left, randomIndex);

参考代码 2

import java.util.ArrayList;
import java.util.Random;

public class Solution {

    private static final Random random = new Random();

    public ArrayList<Integer> GetLeastNumbers_Solution(int[] input, int k) {
        if (k == 0) {
            return new ArrayList<>();
        }

        int len = input.length;
        int left = 0;
        int right = len - 1;

        // 找最小的 K 个数,即找下标 k - 1
        while (left <= right) {
            int index = partition(input, left, right);
            if (index == k - 1) {
                ArrayList<Integer> res = new ArrayList<>();
                for (int i = 0; i < k; i++) {
                    res.add(input[i]);
                }
                return res;
            } else if (index > k - 1) {
                right = index - 1;
            } else {
                left = index + 1;
            }
        }
        return null;
    }

    private int partition(int[] input, int left, int right) {
        int randomIndex = left + random.nextInt(right - left + 1);
        swap(input, left, randomIndex);

        // input[left + 1..le] <= pivot
        // input(le..i) > pivot
        int pivot = input[left];
        int le = left;
        for (int i = left + 1; i <= right; i++) {
            if (input[i] <= pivot) {
                le++;
                swap(input, le, i);
            }
        }
        swap(input, left, le);
        return le;
    }

    private void swap(int[] input, int index1, int index2) {
        int temp = input[index1];
        input[index1] = input[index2];
        input[index2] = temp;
    }
}

复杂度分析

  • 时间复杂度:,这里 是输入数组的长度;
  • 空间复杂度:,只使用到常数个变量。

方法二:使用优先队列(堆)

优先队列(堆)专门用于动态选出最值,选最小的 k 个元素。这里可以选择最小堆:

  • 堆顶的元素是堆里的最大的元素;
  • 当放入堆的元素超过 k 个,即恰好 k + 1 个的时候,把堆顶元素弹出。

通过这种方式,遍历完数组中的所有元素,堆中所有的元素就是数组里最小的 k 个元素。

import java.util.ArrayList;
import java.util.Collections;
import java.util.PriorityQueue;

public class Solution {

    public ArrayList<Integer> GetLeastNumbers_Solution(int[] input, int k) {
        if (k == 0) {
            return new ArrayList<>();
        }

        // 保存 k + 1 个元素
        PriorityQueue<Integer> minHeap = new PriorityQueue<>(Collections.reverseOrder());
        for (int num : input) {
            minHeap.add(num);
            if (minHeap.size() > k) {
                minHeap.poll();
            }
        }

        ArrayList<Integer> res = new ArrayList<>();
        for (int num : minHeap) {
            res.add(num);
        }
        return res;
    }
}

复杂度分析

  • 时间复杂度:,这里 是输入数组的长度;
  • 空间复杂度:,优先队列中最多存放 个元素。