寻找两个正序数组的中位数

一、题目描述 (注意:不去重!!!)

给定两个大小为 m 和 n 的正序(从小到大)数组 n u m s 1 nums1 nums1 n u m s 2 nums2 nums2
请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 O ( l o g ( m + n ) ) O(log(m + n)) O(log(m+n))

你可以假设 nums1 和 nums2 不会同时为空。

示例 1:

nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0

示例 2:

nums1 = [1,2,3]
nums2 = [1,2,3,4,5,6,7,8,9]
num = [1,1,2,2,3,   3,4,   5,6,7,8,9]
则中位数是 (3 + 4)/2 = 3.5

二、解题思路 & 代码

题目的要求 O ( l o g ( m + n ) O(log(m+n) O(log(m+n)。看到 l o g log log,很明显,我们只有用到二分的方法才能达到。

把问题转化为 在两个数组中找第 k k k 个大的数,求中位数就是求 第 k k k 小数的一种特殊情况(比如,两个数组总长度为 14,那么 k = ( 14 + 1 ) / / 2 = 7 k = (14+1) // 2 = 7 k=(14+1)//2=7

假设我们要找第 k k k 小数,我们可以每次循环排除掉 k / 2 k/2 k/2 个数

A \bf A A: A [ 1 ] , A [ 2 ] , A [ 3 ] , A [ k / 2 ] . . . A[1] ,A[2] ,A[3],A[k/2] ... A[1]A[2]A[3]A[k/2]...
B \bf B B: B [ 1 ] , B [ 2 ] , B [ 3 ] , B [ k / 2 ] . . . B[1],B[2],B[3],B[k/2] ... B[1]B[2]B[3]B[k/2]...

如果 A [ k / 2 ] < B [ k / 2 ] A[k/2]<B[k/2] A[k/2]<B[k/2] ,那么 A [ 1 ] , A [ 2 ] , A [ 3 ] , A [ k / 2 ] A[1],A[2],A[3],A[k/2] A[1]A[2]A[3]A[k/2]都不可能是第 k k k 小的数字。

A A A 数组中比 A [ k / 2 ] A[k/2] A[k/2] 小的数有 k / 2 − 1 k/2-1 k/21 个, B B B 数组中, B [ k / 2 ] < A [ k / 2 ] B[k/2] < A[k/2] B[k/2]<A[k/2] ,假设 B [ k / 2 ] B[k/2] B[k/2] 前边的数字都比 A [ k / 2 ] A[k/2] A[k/2] 小,也只有 k / 2 − 1 k/2-1 k/21 个,所以比 A [ k / 2 ] A[k/2] A[k/2] 小的数字最多有 k / 2 − 1 + k / 2 − 1 = k − 2 k/2-1+k/2-1=k-2 k/21+k/21=k2 个,所以 A [ k / 2 ] A[k/2] A[k/2] 最多是第 k − 1 k-1 k1 小的数。而比 A [ k / 2 ] A[k/2] A[k/2] 小的数更不可能是第 k k k 小的数了,所以可以把它们排除。

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        n = len(nums1)
        m = len(nums2)
        left = (n + m + 1) // 2
        right = (n + m +2) // 2

        def getKth(nums1, start1, end1, nums2, start2, end2, k):
            L1 = end1 - start1 + 1
            L2 = end2 - start2 + 1
            # 让 L1 长度小于 L2,这样如果有数组为空,则一定是 L1
            if (L1 > L2): return getKth(nums2, start2, end2, nums1, start1, end1, k)
            if (L1 == 0): return nums2[start2 + k - 1]

            if (k == 1): return min(nums1[start1], nums2[start2])
            
            i = start1 + min(L1, k // 2) - 1
            j = start2 + min(L2, k // 2) - 1

            if (nums1[i] > nums2[j]):
                return getKth(nums1, start1, end1, nums2, j + 1, end2, k - (j - start2 + 1))
            else:
                return getKth(nums1, i+1, end1, nums2, start2, end2, k - (i - start1 + 1))

        # 将偶数和奇数情况合并,如果是奇数,会求同样的 k
        return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + \
               getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) * 0.5

复杂度分析:

  1. 时间复杂度: O ( l o g ( m + n ) O(log(m+n) O(log(m+n)。每进行一次循环,我们就减少 k/2 个元素,所以时间复杂度是 O(log(k),而 k=(m+n)/2,所以最终的复杂也就是 O ( l o g ( m + n ) O(log(m+n) O(log(m+n)

  2. 空间复杂度: O ( 1 ) O(1) O(1) 。虽然我们用到了递归,但是可以看到这个递归属于尾递归,所以编译器不需要不停地堆栈,所以空间复杂度为 O(1)

另外附上 LeetCode 官方题解:

class Solution:
    def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
        def getKthElement(k):
            """ - 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较 - 这里的 "/" 表示整除 - nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个 - nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个 - 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个 - 这样 pivot 本身最大也只能是第 k-1 小的元素 - 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组 - 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组 - 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数 """
            
            index1, index2 = 0, 0
            while True:
                # 特殊情况
                if index1 == m:
                    return nums2[index2 + k - 1]
                if index2 == n:
                    return nums1[index1 + k - 1]
                if k == 1:
                    return min(nums1[index1], nums2[index2])

                # 正常情况
                newIndex1 = min(index1 + k // 2 - 1, m - 1)
                newIndex2 = min(index2 + k // 2 - 1, n - 1)
                pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
                if pivot1 <= pivot2:
                    k -= newIndex1 - index1 + 1
                    index1 = newIndex1 + 1
                else:
                    k -= newIndex2 - index2 + 1
                    index2 = newIndex2 + 1
        
        m, n = len(nums1), len(nums2)
        totalLength = m + n
        if totalLength % 2 == 1:
            return getKthElement((totalLength + 1) // 2)
        else:
            return (getKthElement(totalLength // 2) + getKthElement(totalLength // 2 + 1)) / 2

参考:

  1. LeetCode题解

==============================================================================================================================

两个有序数组找第k大

解题思路 & 代码

其实就是上一个问题的子问题

def getKthElement(nums1, nums2, k):
    """ - 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较 - 这里的 "/" 表示整除 - nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个 - nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个 - 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个 - 这样 pivot 本身最大也只能是第 k-1 小的元素 - 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组 - 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组 - 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数 """
    m = len(nums1)
    n = len(nums2)
    index1, index2 = 0, 0
    while True:
        # 特殊情况
        if index1 == m:
            return nums2[index2 + k - 1]
        if index2 == n:
            return nums1[index1 + k - 1]
        if k == 1:
            return min(nums1[index1], nums2[index2])

        # 正常情况
        newIndex1 = min(index1 + k // 2 - 1, m - 1)
        newIndex2 = min(index2 + k // 2 - 1, n - 1)
        pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
        if pivot1 <= pivot2:
            k -= newIndex1 - index1 + 1
            index1 = newIndex1 + 1
        else:
            k -= newIndex2 - index2 + 1
            index2 = newIndex2 + 1

if __name__ == '__main__':
    nums1 = [1,3,4,9]
    nums2 = [1,2,2,3,4,5,6,7,8,9,10]
    k = 11
    res = getKthElement(nums1, nums2, k)
    print(res)