我说想要解决TopK问题,首先的话,你需要去熟练掌握两种排序算法,①、快速排序,②、堆排序。
快速排序
快速排序的基本思想:
- 先从数列中取出一个数作为基准数
- 分区过程,将比这个数大的数全放到它的右边,小于或等于它的数全放到它的左边(或者相反,根据需要升序还是降序来)
- 再对左右区间重复第二步,直到各区间只有一个数
import java.util.Arrays;
public class sorts {
public static void quickSort(int[] arr, int begin, int end) {
if (begin < end) {
int mid = getMiddle(arr, begin, end);
quickSort(arr, begin, mid);
quickSort(arr, mid + 1, end);
}
}
private static int getMiddle(int[] arr, int begin, int end) {
int mid = arr[begin];
int left = begin;
int right = end;
while (left < right) {
while (left < right && mid <= arr[right]) {
right--;
}
arr[left] = arr[right];
while (left < right && mid >= arr[left]) {
left++;
}
arr[right] = arr[left];
}
arr[left] = mid;
return left;
}
public static void main(String[] args) {
int[] arr = {1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1};
System.out.print("排序前:");
System.out.println(Arrays.toString(arr));
quickSort(arr, 0, arr.length - 1);
System.out.print("排序后:");
System.out.println(Arrays.toString(arr));
}
}
快排优化(三分取中法)
对于一些特殊的数据,比如整体上来说几乎有序的数据,在使用普通的快排会非常慢,此时我们就需要优化!
采用三数取中法,也就是取左端、中间、右端三个数,然后进行排序,确保中间值最小,然后将中间数作为枢纽值。
public class Main {
public void quickSort(int[] arr, int k,int start,int end) {
if (start < end) {
int mid = getMiddle(arr,start,end);
quickSort(arr,k,start,mid - 1);
quickSort(arr,k,mid + 1,end);
}
}
private int getMiddle(int[] arr, int start, int end) {
int mid = start + (end - start) / 2;
if (arr[start] > arr[end])
swap(arr, start, end);
// 保证中间较小
if (arr[mid] > arr[end])
swap(arr, mid, end);
// 保证中间最小,左右最大
if (arr[mid] > arr[start])
swap(arr, start, mid);
int left = start;
int right = end;
int pivot = arr[left];
while (left < right) {
while(left < right && arr[right] >= pivot) {
right--;
}
arr[left] = arr[right];
while (left < right && arr[left] <= pivot) {
left++;
}
arr[right] = arr[left];
}
arr[left] = pivot;
return left;
}
private void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
}
堆排序
堆排序的思想
- 先将待排序的序列建成大根堆,使得每个父节点的元素大于等于它的子节点。
- 此时整个序列最大值即为堆顶元素,我们将其与末尾元素交换,使末尾元素为最大值。
- 然后再调整堆顶元素使得剩下的 n-1 个元素仍为大根堆,再重复执行以上操作我们即能得到一个有序的序列。
小顶堆与之相反
堆实际上是一棵完全二叉树,其任何一非叶节点满足性质:
heap[i] <= heap[2i+1] && heap[i] <= heap[2i+2] 或者 heap[i] >= heap[2i+1] && heap >= heap[2i+2]
即任何一非叶节点的关键字不大于或者不小于其左右孩子节点的关键字。
小顶堆
public class Main {
public static void heapSort(int[] arr,int heapSize) {
//上浮
for (int i = heapSize / 2 - 1; i >= 0; i--) {
builderHeap(arr,i,arr.length);
}
//下沉
for (int i = heapSize - 1; i >= 0; i--) {
swap(arr,0,i);
builderHeap(arr,0,i);
}
}
private static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
private static void builderHeap(int[] arr, int index, int length) {
//当前节点
int tmp = arr[index];
//左子节点
for (int i = index * 2 + 1; i < length; i = i * 2 + 1) {
//如果右子节点值大于左子节点
if (i + 1 < length && arr[i + 1] > arr[i]) {
i++;
}
//如果左子节点和右子节点的最大值大于父节点,则进行交换
if (arr[i] > tmp) {
arr[index] = arr[i];
index = i;
}else
break;
}
arr[index] = tmp;
}
public static void main(String[] args) {
int[] a = {1,2,4,0,3,-1,6,2};
System.out.println("堆排序前:" + Arrays.toString(a));
heapSort(a,a.length);
System.out.println("堆排序后:" + Arrays.toString(a));
}
}
大顶堆
public class sorts {
public static void heapSort(int[] arr,int heapSize) {
for (int i = heapSize / 2 - 1; i >= 0; i--) {
builderHeap(arr,i,arr.length);
}
for (int i = heapSize - 1; i >= 0; i--) {
swap(arr,0,i);
builderHeap(arr,0,i);
}
}
private static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
private static void builderHeap(int[] arr, int index, int length) {
int tmp = arr[index];
for (int i = index * 2 + 1; i < length; i = i * 2 + 1) {
if (i + 1 < length && arr[i + 1] < arr[i]) {
i++;
}
if (arr[i] < tmp) {
arr[index] = arr[i];
index = i;
}else
break;
}
arr[index] = tmp;
}
public static void main(String[] args) {
int[] a = {3,2,1,5,6,4};
System.out.println("堆排序前:" + Arrays.toString(a));
heapSort(a,a.length);
System.out.println("堆排序后:" + Arrays.toString(a));
}
}
解决TopK问题
利用快速排序分区思想找第 K 小元素
import java.util.Arrays;
/**
* @author dong
* @date 2021/4/8 19:41
*/
public class sorts {
public static int getTopK(int[] arr, int k) {
quickSort(arr, 0, arr.length - 1, k - 1);
return arr[k - 1];
}
public static void quickSort(int[] arr, int begin, int end, int k) {
// 每快排切分1次,找到排序后下标为 mid 的元素,如果 mid 恰好等于 k 就返回 mid 以及 mid 左边所有的数;
int mid = getMiddle(arr, begin, end);
if (mid == k) {
System.out.println(Arrays.toString(arr));
return;
}
//根据 mid 和 k 的大小确定继续切分左段还是右段。
if (k > mid) {
quickSort(arr, mid + 1, end, k);
} else
quickSort(arr, begin, mid - 1, k);
}
private static int getMiddle(int[] arr, int begin, int end) {
int mid = arr[begin];
int left = begin;
int right = end;
while (left < right) {
while (left < right && mid <= arr[right]) {
right--;
}
arr[left] = arr[right];
while (left < right && mid >= arr[left]) {
left++;
}
arr[right] = arr[left];
}
arr[left] = mid;
return left;
}
public static void main(String[] args) {
int[] arr = {1, 2, 3, 4, 5, 6, 7, 6, 5, 4, 3, 2, 1};
int topK = getTopK(arr, 3);
System.out.println("第3大元素为:" + topK);
}
}
快排获取最小的 K 个元素
public class Main {
public static int[] getLeastNumbers(int[] arr, int k) {
int len = arr.length;
if (len == 0 || k == 0) return new int[0];
//k - 1就是第k个元素的下标,我们要返回前k个元素
quickSort(arr,0,arr.length - 1,k);
int[] res = new int[k];
for(int i = 0;i < k;i++) {
res[i] = arr[i];
}
return res;
}
public static void quickSort(int[] arr,int start,int end,int k) {
if (start < end) {
int mid = getMiddle(arr,start,end);
if (mid == k) {
return;
}
//根据 mid 和 k 的大小确定继续切分左段还是右段。
if (k > mid) {
quickSort(arr, mid + 1, end, k);
} else
quickSort(arr, start, mid - 1, k);
}
}
private static int getMiddle(int[] arr, int start, int end) {
int mid = start + (end - start) / 2;
if (arr[start] > arr[end]) {
swap(arr,start,end);
}
if (arr[mid] > arr[end]) {
swap(arr,mid,end);
}
if (arr[mid] > arr[start]) {
swap(arr,start,mid);
}
int pivot = arr[start];
int left = start;
int right = end;
while (left < right) {
while (left < right && arr[right] >= pivot) {
right--;
}
arr[left] = arr[right];
while (left < right && arr[left] <= pivot) {
left++;
}
arr[right] = arr[left];
}
arr[left] = pivot;
return left;
}
private static void swap(int[] arr, int i, int j) {
int tmp = arr[i];
arr[i] = arr[j];
arr[j] = tmp;
}
public static void main(String[] args) {
int[] arr = {2,5,4,10,1,0,6,9,8,-2};
System.out.println(Arrays.toString(arr));
int k = 4;
int[] res = getLeastNumbers(arr, k);
System.out.println("最小的" + k + "个元素:" + Arrays.toString(res));
}
}
构造固定堆解决最小K个元素
要获取最大K个元素,我们来构建大顶堆
public class Main {
private static int[] topK(int[] data, int k) {
int[] topK = new int[k];
//构造固定大小堆
for (int i = 0; i < k; i++) {
topK[i] = data[i];
}
buildHeap(topK);
for (int i = k; i < data.length; i++) {
int root = topK[0];
//如果比堆顶元素小
if (data[i] < root) {
topK[0] = data[i];
heapify(topK,0,topK.length);
}
}
return topK;
}
private static void buildHeap(int[] data) {
//从最后一个父节点的下标开始遍历 子推父:(data.length - 1 - 1)/2
int heapSize = data.length;
for (int i = heapSize / 2 - 1; i >= 0; i--) {
heapify(data,i,heapSize);
}
}
private static void heapify(int[] arr,int index,int len) {
int tmp = arr[index];
for (int i = index * 2 + 1; i < len; i = i * 2 + 1) {
if (i + 1 < len && arr[i + 1] > arr[i]) {
i += 1;
}
if (arr[i] > tmp) {
arr[index] = arr[i];
index = i;
}else
break;
}
arr[index] = tmp;
}
//测试
public static void main(String[] args) {
int[] data = {12, 10, 4, 7, 30, 9, 6, 20};
int[] topK = topK(data, 3);
System.out.println(Arrays.toString(topK));
}
}
构造固定堆解决最大K个元素
要获取最大K个元素,我们来构建小顶堆
public class MinHeap {
private static int[] topK(int[] data, int k) {
int[] topK = new int[k];
//构造固定大小堆
for (int i = 0; i < k; i++) {
topK[i] = data[i];
}
buildHeap(topK);
for (int i = k; i < data.length; i++) {
int root = topK[0];
//如果比堆顶元素大
if (data[i] > root) {
topK[0] = data[i];
//重新构建堆
heapify(topK,0,topK.length);
}
}
return topK;
}
private static void buildHeap(int[] data) {
//从最后一个父节点的下标开始遍历 子推父:(data.length - 1 - 1)/2
int heapSize = data.length;
for (int i = heapSize / 2 - 1; i >= 0; i--) {
heapify(data,i,heapSize);
}
}
private static void heapify(int[] arr,int index,int len) {
int tmp = arr[index];
for (int i = index * 2 + 1; i < len; i = i * 2 + 1) {
if (i + 1 < len && arr[i + 1] < arr[i]) {
i += 1;
}
if (arr[i] < tmp) {
arr[index] = arr[i];
index = i;
}else
break;
}
arr[index] = tmp;
}
//测试
public static void main(String[] args) {
int[] data = {12, 10, 4, 7,11, 30, 9, 6, 20};
int[] topK = topK(data, 5);
System.out.println(Arrays.toString(topK));
}
}