寻找两个正序数组的中位数
一、题目描述 (注意:不去重!!!)
给定两个大小为 m 和 n 的正序(从小到大)数组 n u m s 1 nums1 nums1 和 n u m s 2 nums2 nums2。
请你找出这两个正序数组的中位数,并且要求算法的时间复杂度为 O ( l o g ( m + n ) ) O(log(m + n)) O(log(m+n))。
你可以假设 nums1 和 nums2 不会同时为空。
示例 1:
nums1 = [1, 3]
nums2 = [2]
则中位数是 2.0
示例 2:
nums1 = [1,2,3]
nums2 = [1,2,3,4,5,6,7,8,9]
num = [1,1,2,2,3, 3,4, 5,6,7,8,9]
则中位数是 (3 + 4)/2 = 3.5
二、解题思路 & 代码
题目的要求 O ( l o g ( m + n ) O(log(m+n) O(log(m+n)。看到 l o g log log,很明显,我们只有用到二分的方法才能达到。
把问题转化为 在两个数组中找第 k k k 个大的数,求中位数就是求 第 k k k 小数的一种特殊情况(比如,两个数组总长度为 14,那么 k = ( 14 + 1 ) / / 2 = 7 k = (14+1) // 2 = 7 k=(14+1)//2=7)
假设我们要找第 k k k 小数,我们可以每次循环排除掉 k / 2 k/2 k/2 个数
A \bf A A: A [ 1 ] , A [ 2 ] , A [ 3 ] , A [ k / 2 ] . . . A[1] ,A[2] ,A[3],A[k/2] ... A[1],A[2],A[3],A[k/2]...
B \bf B B: B [ 1 ] , B [ 2 ] , B [ 3 ] , B [ k / 2 ] . . . B[1],B[2],B[3],B[k/2] ... B[1],B[2],B[3],B[k/2]...
如果 A [ k / 2 ] < B [ k / 2 ] A[k/2]<B[k/2] A[k/2]<B[k/2] ,那么 A [ 1 ] , A [ 2 ] , A [ 3 ] , A [ k / 2 ] A[1],A[2],A[3],A[k/2] A[1],A[2],A[3],A[k/2]都不可能是第 k k k 小的数字。
A A A 数组中比 A [ k / 2 ] A[k/2] A[k/2] 小的数有 k / 2 − 1 k/2-1 k/2−1 个, B B B 数组中, B [ k / 2 ] < A [ k / 2 ] B[k/2] < A[k/2] B[k/2]<A[k/2] ,假设 B [ k / 2 ] B[k/2] B[k/2] 前边的数字都比 A [ k / 2 ] A[k/2] A[k/2] 小,也只有 k / 2 − 1 k/2-1 k/2−1 个,所以比 A [ k / 2 ] A[k/2] A[k/2] 小的数字最多有 k / 2 − 1 + k / 2 − 1 = k − 2 k/2-1+k/2-1=k-2 k/2−1+k/2−1=k−2 个,所以 A [ k / 2 ] A[k/2] A[k/2] 最多是第 k − 1 k-1 k−1 小的数。而比 A [ k / 2 ] A[k/2] A[k/2] 小的数更不可能是第 k k k 小的数了,所以可以把它们排除。
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
n = len(nums1)
m = len(nums2)
left = (n + m + 1) // 2
right = (n + m +2) // 2
def getKth(nums1, start1, end1, nums2, start2, end2, k):
L1 = end1 - start1 + 1
L2 = end2 - start2 + 1
# 让 L1 长度小于 L2,这样如果有数组为空,则一定是 L1
if (L1 > L2): return getKth(nums2, start2, end2, nums1, start1, end1, k)
if (L1 == 0): return nums2[start2 + k - 1]
if (k == 1): return min(nums1[start1], nums2[start2])
i = start1 + min(L1, k // 2) - 1
j = start2 + min(L2, k // 2) - 1
if (nums1[i] > nums2[j]):
return getKth(nums1, start1, end1, nums2, j + 1, end2, k - (j - start2 + 1))
else:
return getKth(nums1, i+1, end1, nums2, start2, end2, k - (i - start1 + 1))
# 将偶数和奇数情况合并,如果是奇数,会求同样的 k
return (getKth(nums1, 0, n - 1, nums2, 0, m - 1, left) + \
getKth(nums1, 0, n - 1, nums2, 0, m - 1, right)) * 0.5
复杂度分析:
-
时间复杂度: O ( l o g ( m + n ) O(log(m+n) O(log(m+n)。每进行一次循环,我们就减少 k/2 个元素,所以时间复杂度是 O(log(k),而 k=(m+n)/2,所以最终的复杂也就是 O ( l o g ( m + n ) O(log(m+n) O(log(m+n)。
-
空间复杂度: O ( 1 ) O(1) O(1) 。虽然我们用到了递归,但是可以看到这个递归属于尾递归,所以编译器不需要不停地堆栈,所以空间复杂度为 O(1)
另外附上 LeetCode 官方题解:
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
def getKthElement(k):
""" - 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较 - 这里的 "/" 表示整除 - nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个 - nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个 - 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个 - 这样 pivot 本身最大也只能是第 k-1 小的元素 - 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组 - 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组 - 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数 """
index1, index2 = 0, 0
while True:
# 特殊情况
if index1 == m:
return nums2[index2 + k - 1]
if index2 == n:
return nums1[index1 + k - 1]
if k == 1:
return min(nums1[index1], nums2[index2])
# 正常情况
newIndex1 = min(index1 + k // 2 - 1, m - 1)
newIndex2 = min(index2 + k // 2 - 1, n - 1)
pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
if pivot1 <= pivot2:
k -= newIndex1 - index1 + 1
index1 = newIndex1 + 1
else:
k -= newIndex2 - index2 + 1
index2 = newIndex2 + 1
m, n = len(nums1), len(nums2)
totalLength = m + n
if totalLength % 2 == 1:
return getKthElement((totalLength + 1) // 2)
else:
return (getKthElement(totalLength // 2) + getKthElement(totalLength // 2 + 1)) / 2
参考:
==============================================================================================================================
两个有序数组找第k大
解题思路 & 代码
其实就是上一个问题的子问题
def getKthElement(nums1, nums2, k):
""" - 主要思路:要找到第 k (k>1) 小的元素,那么就取 pivot1 = nums1[k/2-1] 和 pivot2 = nums2[k/2-1] 进行比较 - 这里的 "/" 表示整除 - nums1 中小于等于 pivot1 的元素有 nums1[0 .. k/2-2] 共计 k/2-1 个 - nums2 中小于等于 pivot2 的元素有 nums2[0 .. k/2-2] 共计 k/2-1 个 - 取 pivot = min(pivot1, pivot2),两个数组中小于等于 pivot 的元素共计不会超过 (k/2-1) + (k/2-1) <= k-2 个 - 这样 pivot 本身最大也只能是第 k-1 小的元素 - 如果 pivot = pivot1,那么 nums1[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums1 数组 - 如果 pivot = pivot2,那么 nums2[0 .. k/2-1] 都不可能是第 k 小的元素。把这些元素全部 "删除",剩下的作为新的 nums2 数组 - 由于我们 "删除" 了一些元素(这些元素都比第 k 小的元素要小),因此需要修改 k 的值,减去删除的数的个数 """
m = len(nums1)
n = len(nums2)
index1, index2 = 0, 0
while True:
# 特殊情况
if index1 == m:
return nums2[index2 + k - 1]
if index2 == n:
return nums1[index1 + k - 1]
if k == 1:
return min(nums1[index1], nums2[index2])
# 正常情况
newIndex1 = min(index1 + k // 2 - 1, m - 1)
newIndex2 = min(index2 + k // 2 - 1, n - 1)
pivot1, pivot2 = nums1[newIndex1], nums2[newIndex2]
if pivot1 <= pivot2:
k -= newIndex1 - index1 + 1
index1 = newIndex1 + 1
else:
k -= newIndex2 - index2 + 1
index2 = newIndex2 + 1
if __name__ == '__main__':
nums1 = [1,3,4,9]
nums2 = [1,2,2,3,4,5,6,7,8,9,10]
k = 11
res = getKthElement(nums1, nums2, k)
print(res)