import java.util.*;

public class Solution {
    /**
     * Returns the {@code k} smallest values of {@code input} using quickselect.
     *
     * <p>The returned values are not sorted; their order is whatever the
     * partitioning leaves in the first {@code k} slots. The caller's array is
     * not modified: selection runs on a private copy.
     *
     * @param input the values to select from; may be {@code null}
     * @param k     how many of the smallest values to return
     * @return a new list holding the {@code k} smallest values, or an empty
     *         list when {@code input} is {@code null} or empty, when
     *         {@code k <= 0}, or when {@code k} exceeds {@code input.length}
     */
    public ArrayList<Integer> GetLeastNumbers_Solution(int [] input, int k) {
        // k <= 0 (not just k == 0): a negative k can never match a partition
        // index, so the selection loop below would otherwise spin forever.
        if (input == null || input.length == 0 || k <= 0 || input.length < k) {
            return new ArrayList<>();
        }

        // Partition a copy so the caller's array keeps its original order.
        int[] work = input.clone();
        int target = k - 1;
        int start = 0;
        int end = work.length - 1;

        // Invariant: target always lies in [start, end], so each round narrows
        // the range until the pivot lands exactly on the target slot.
        int index = partition(work, start, end);
        while (index != target) {
            if (index < target) {
                start = index + 1;
            } else {
                end = index - 1;
            }
            index = partition(work, start, end);
        }

        // Everything at or left of the pivot is <= the pivot, i.e. the k smallest.
        ArrayList<Integer> res = new ArrayList<>(k);
        for (int i = 0; i <= index; i++) {
            res.add(work[i]);
        }
        return res;
    }

    /**
     * Partitions {@code input[start..end]} around the value initially stored
     * at {@code input[start]}.
     *
     * <p>On return, elements left of the returned index are strictly smaller
     * than the pivot and elements right of it are greater than or equal to it.
     *
     * @param input the array to rearrange in place
     * @param start first index of the range (inclusive)
     * @param end   last index of the range (inclusive)
     * @return the final index of the pivot value
     */
    private static int partition(int[] input, int start, int end) {
        int pivot = input[start];
        while (start < end) {
            // Scan from the right for a value smaller than the pivot and
            // swap it with the pivot, which currently sits at `start`.
            while (start < end && input[end] >= pivot) {
                end--;
            }
            swap(input, start, end);
            // Scan from the left for a value not smaller than the pivot and
            // swap it with the pivot, which currently sits at `end`.
            while (start < end && input[start] < pivot) {
                start++;
            }
            swap(input, start, end);
        }
        return start;
    }

    /** Exchanges the elements at positions {@code left} and {@code right}. */
    private static void swap(int[] input, int left, int right) {
        int tmp = input[left];
        input[left] = input[right];
        input[right] = tmp;
    }
}