知识点

递归 分治

思路

这道题算上一题的强化版,三个有序数组,一个长度为n,一个长度为m,一个长度为p,找到三个数组归并后的中位数。

根据中位数的定义,假如m+n+p是奇数,那么需要找到第(n+m+p)/2+1小的数作为答案;假如是偶数,那么需要找到第(n + m+p) /2 小的数和(n+m+p)/2+1小的数的平均数作为中位数。

因此我们需要实现一个函数find可以在log时间内找到三个有序数组的第u小的数。

对于找两个有序数组的第u小数的函数find(vector<int>& nums1, int i, vector<int>& nums2, int j, vector<int>& nums3, int k, int u):

其中i是第一个数组第一个可以选的位置,j是第二个数组第一个可以选的位置,k是第三个数组第一个可以选的位置

首先,假如三个有序数组中的剩余未选的元素个数均大于等于u/2,那么我们先在三个数组的剩余元素中先各选u/2个, 比较出最小的位置, 那么该数组的u/2个元素被剔除, 可以把问题缩小为原来的一半

考虑边界情况,假如有一个数组不足够取u/2个元素了,那么把剩余的元素全取;u=1时选取三个数组中可以选择的最小元素即可

实现上始终让数组们的顺序按照备选元素个数递增用来简化代码逻辑。

时间复杂度

由于用分治思想,每次可以解决原问题的一半,时间复杂度为O(log(n+m+p))

#define x first
#define y second
class Solution {
public:
    /**
     * 代码中的类名、方法名、参数名已经指定,请勿修改,直接返回方法规定的值即可
     *
     * 
     * @param herd1 int整型vector 
     * @param herd2 int整型vector 
     * @param herd3 int整型vector 
     * @return double浮点型
     */
    double findMedianSortedArray(vector<int>& herd1, vector<int>& herd2, vector<int>& herd3) {
        int tot = herd1.size() + herd2.size() + herd3.size();
        if (!(tot & 1)) {
            int left = find(herd1, 0, herd2, 0, herd3, 0, tot / 2);
            int right = find(herd1, 0, herd2, 0, herd3, 0, tot / 2 + 1);
            return (double)(left + right) / 2.0;
        }
        return (double)find(herd1, 0, herd2, 0, herd3, 0, tot / 2 + 1);
    }
    int find(vector<int>& nums1, int i, vector<int>& nums2, int j, vector<int>& nums3, int k, int u) {
        if (nums1.size() - i > nums2.size() - j) swap(nums1, nums2), swap(i, j);
        if (nums1.size() - i > nums3.size() - k) swap(nums1, nums3), swap(i, k);
        if (nums2.size() - j > nums3.size() - k) swap(nums2, nums3), swap(j, k);

        if (u == 1) {
            if (nums1.size() == i and nums2.size() == j) return nums3[k];
            else if (nums1.size() == i) return min(nums2[j], nums3[k]);
            else return min(nums1[i], min(nums2[j], nums3[k]));
        }
        if (nums1.size() == i and nums2.size() == j) return nums3[k + u - 1];
        else if (nums1.size() == i) {
            int sj = min(j + u / 2, (int)nums2.size()), sk = k + u - u / 2;
            if (nums2[sj - 1] > nums3[sk - 1]) return find(nums1, i, nums2, j, nums3, sk, u - (sk - k));
            else return find(nums1, i, nums2, sj, nums3, k, u - (sj - j));
        }
        int si = min(i + u / 2, (int)nums1.size()), sj = j + u / 2, sk = k + u / 2;
        int mn = min({nums1[si - 1], nums2[sj - 1], nums3[sk - 1]});
        if (nums1[si - 1] == mn) return find(nums1, si, nums2, j, nums3, k, u - (si - i));
        else if (nums2[sj - 1] == mn) return find(nums1, i, nums2, sj, nums3, k, u - (sj - j));
        return find(nums1, i, nums2, j, nums3, sk, u - (sk - k));
    }
};