给定一个 n x n 矩阵,其中每行和每列元素均按升序排序,找到矩阵中第 k 小的元素。
请注意,它是排序后的第 k 小元素,而不是第 k 个不同的元素。

最先想到的是堆排序发,但是这道题还有复杂度更好的算法,二分法,结合了剑指offer上矩阵查找的问题。

方法一、 直接比较法

将二维数组里的数据放入一未数组,然后对一维数组进行排序,返回第k-1个元素。

class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        vector<int> rec;
        for (auto& row : matrix) {
            for (int it : row) {
                rec.push_back(it);
            }
        }
        sort(rec.begin(), rec.end());
        return rec[k - 1];
    }
};

方法二、 堆排序法

使用最大堆

class Solution {
public:
    int kthSmallest(vector<vector<int>>& matrix, int k) {
        priority_queue<int> pq;
        for (int i = 0; i < matrix.size(); ++i) {
            for (int j = 0; j < matrix[0].size(); ++j) {
                pq.push(matrix[i][j]);
                if (pq.size() > k) pq.pop();
            }
        }
        return pq.top();
    }
};

方法三、二分查找

参考:https://leetcode-cn.com/problems/kth-smallest-element-in-a-sorted-matrix/solution/you-xu-ju-zhen-zhong-di-kxiao-de-yuan-su-by-leetco/

思路:
整个二维数组中 matrix[0][0]为最小值,matrix[n−1][n−1] 为最大值,现在我们将其分别记作 l 和 r。

可以发现一个性质:任取一个数 mid 满足 l小于等于mid小于等于r,那么矩阵中不大于mid 的数,肯定全部分布在矩阵的左上角。而且需要注意的是mid不一定会存在矩阵中,它只是一个判断的边界条件。当left和right逐渐靠近且相等时,输出。
图片说明

计算矩阵中有多少数不大于 midmid :
如果数量不少于 k,那么说明最终答案 x不大于 mid;
如果数量少于 k,那么说明最终答案 x 大于 mid。

如题解中所说,如果 num >= k,那么说明最终答案 x <= mid;如果 num < k,那么说明最终答案 x > mid。 在最后一次迭代时,check返回的结果为false,即 num < k,说明 x > mid,又因为 x <= right。当 left = mid + 1 后,left > righ,while循环结束。此时有 mid < x <= right < left = mid + 1,即 mid < x <= mid + 1。可得 x = mid + 1 = left。
也就是说,最后的判断条件一定是left != right 即left == right 时 输出结果,输出left和right都是一样的。

class Solution {
   public:
    bool check(vector<vector<int>>& matrix, int mid, int k, int n) {
        int i = n - 1;
        int j = 0;
        int num = 0;
        while (i >= 0 && j < n) {
            if (matrix[i][j] <= mid) {
                num += i + 1;
                j++;
            } else {
                i--;
            }
        }
        return num >= k;//返回true,否则返回false
    }

    int kthSmallest(vector<vector<int>>& matrix, int k) {
        int n = matrix.size();
        int left = matrix[0][0];
        int right = matrix[n - 1][n - 1];
        while (left < right) {
            int mid = left + ((right - left) >> 1);
            if (check(matrix, mid, k, n)) {
                right = mid;
            } else {
                left = mid + 1;
            }
        }
        return left;
    }
};