题目描述
解题思路
中位数需要根据两个数组长度和的奇偶决定:
- 假设
nums1.length = m,nums2.length = n
- 若
(m + n) % 2 == 0
,表示两数组长度之和为偶数,中位数则是中间两个数 - 否则为奇数,中位数是中间的数
但是我们可以不同通过分别考虑来计算最终值,只需要通过第
(m + n + 1) / 2
个数 + 第(m + n + 2) / 2
个数除以2就可以得到最终结果
1、不考虑复杂度
如果我们不考虑算法的复杂度,很容易想到,将两个数组拼成一个大数组,然后在大数组中求中位数就OK了,此时算法的复杂度是O(m+n)
这里只给出 Java
代码:
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int n1 = nums1.length;
int n2 = nums2.length;
int len = n1 + n2;
int[] merge = new int[len];
int i = 0, j = 0, k = 0;
//n1 = n2
while(i < n1 && j < n2){
if(nums1[i] < nums2[j]){
merge[k] = nums1[i];
i++;
k++;
}else{
merge[k] = nums2[j];
j++;
k++;
}
}
//n1 > n2
while(i < n1){
merge[k] = nums1[i];
i++;
k++;
}
//n1 < n2
while(j < n2){
merge[k] = nums2[j];
j++;
k++;
}
double res = 0.0;
//取中位数
if(len % 2 == 0){
res = (merge[len/2 - 1] + merge[len/2]) * 1.0 / 2;
}else{
res = merge[len/2];
}
return res;
}
}
2、双指针解法
(1)使两个指针分别从两个数组的开头位置开始,比较两个指针所指向的数的大小,谁小谁就往后移动;
(2)通过计算得到中位数的位置,得到两指针总共需要移动的次数 k
- 若
m + n
为偶数,应该移动(m + n) / 2 - 1
和(m + n) / 2
次 - 若
m + n
为奇数,应该移动(m + n) / 2
次
另外需要注意的是,边界条件的考虑,如果其中一个数组已经到最后,但是还没有达到k次,那么就让另一个指针继续往后移动
Java
代码:
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int n1 = nums1.length;
int n2 = nums2.length;
int len = n1 + n2;
int idx1 = 0;
int idx2 = 0;
int x = 0, y = 0;
while(idx1 + idx2 < len){
if(idx1 < n1){
while(idx2 == n2 || nums1[idx1] <= nums2[idx2]){
idx1++;
if(idx1 + idx2 == (len + 1) / 2){
x = nums1[idx1 - 1];
}
if(idx1 + idx2 == (len + 2) / 2){
y = nums1[idx1 - 1];
return (x + y) * 1.0 / 2;
}
if(idx1 == n1){
break;
}
}
}
if(idx2 < n2){
while(idx1 == n1 || nums2[idx2] <= nums1[idx1]){
idx2++;
if(idx1 + idx2 == (len + 1) / 2){
x = nums2[idx2 - 1];
}
if(idx1 + idx2 == (len + 2) / 2){
y = nums2[idx2 - 1];
return (x + y) * 1.0 / 2;
}
if(idx2 == n2){
break;
}
}
}
}
return -1;
}
}
3、二分法
通过以上分析,我们可以发现,其实我们就是需要找到第(m + n + 1) / 2
个数 + 第(m + n + 2) / 2
个数即可。
但是我们需要对各种情况进行讨论:
- 当一个数组比较短的时候,中位数都在另一个数组中,该如何处理
- 如何获得 第
(m + n + 1) / 2
个数和第(m + n + 2) / 2
个数
此外,需要注意的是,当一个数组太短,导致中位数在另一个数组中时,要考虑到边界判断
上图只查找了一个,另一个类似,或者以第一个为基础继续查找
Java
代码
class Solution {
public double findMedianSortedArrays(int[] nums1, int[] nums2) {
int n1 = nums1.length;
int n2 = nums2.length;
int len = n1 + n2;
int md1 = (len + 1) / 2;
int md2 = (len + 2) / 2;
return (getMdVal(nums1, 0, nums2, 0, md1) + getMdVal(nums1, 0, nums2, 0, md2)) * 1.0 / 2;
}
public static int getMdVal(int[] A, int Astart, int[] B, int Bstart, int k){
if (Astart > A.length-1){
return B[Bstart + k -1];
}
if (Bstart > B.length-1){
return A[Astart + k -1];
}
if (k == 1){
return Math.min(A[Astart],B[Bstart]);
}
int Amin = 0, Bmin = 0;
if (Astart + k/2 -1 < A.length){
Amin = A[Astart + k/2 -1];
}
if (Bstart + k/2 -1 < B.length){
Bmin = B[Bstart + k/2 -1];
}
return Amin < Bmin ? getMdVal(A, Astart + k/2, B, Bstart, k-k/2):getMdVal(A, Astart, B, Bstart+k/2, k-k/2);
}
}
Python3
代码
class Solution:
def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float:
m, n = len(nums1), len(nums2)
length = m + n
md1 = (length + 1) // 2
md2 = (length + 2) // 2
def getMdVal(A, Astart, B, Bstart, k):
if Astart > (len(A) - 1):
return B[int(Bstart + k - 1)]
if Bstart > (len(B) - 1):
return A[int(Astart + k - 1)]
if k == 1:
return A[int(Astart)] if (A[int(Astart)] < B[int(Bstart)]) else B[int(Bstart)]
Amin, Bmin = 2**31-1, 2**31-1
if (Astart + k//2 - 1) < len(A):
Amin = A[int(Astart + k//2 - 1)]
if (Bstart + k//2 - 1) < len(B):
Bmin = B[int(Bstart + k//2 - 1)]
return getMdVal(A, Astart + k//2, B, Bstart, k-k//2) if Amin < Bmin else getMdVal(A, Astart, B, Bstart+k//2, k-k//2)
return (getMdVal(nums1, 0, nums2, 0, md1) + getMdVal(nums1, 0, nums2, 0, md2)) / 2
运行结果不怎么好
class Solution:
def findMedianSortedArrays(self, nums1, nums2):
nums = nums1 + nums2
nums.sort()
length = len(nums)
if length == 2:
return (nums[0] + nums[1])/2
if length % 2 == 0:
return (nums[length // 2 - 1] + nums[(length // 2)])/2
return nums[length // 2]
此代码是先拼接,在查找的,貌似两个结果差不多
C++
代码
class Solution {
public:
double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2) {
int m = nums1.size();
int n = nums2.size();
int md1 = (m + n + 1) / 2;
int md2 = (m + n + 2) / 2;
return (getMdVal(nums1, 0, nums2, 0, md1) + getMdVal(nums1, 0, nums2, 0, md2)) / 2.0;
}
int getMdVal(vector<int>& A, int Astart, vector<int>& B, int Bstart, int k) {
if (Astart >= A.size()) return B[Bstart + k - 1];
if (Bstart >= B.size()) return A[Astart + k - 1];
if (k == 1) return min(A[Astart], B[Bstart]);
int midVal1 = (Astart + k / 2 - 1 < A.size()) ? A[Astart + k / 2 - 1] : INT_MAX;
int midVal2 = (Bstart + k / 2 - 1 < B.size()) ? B[Bstart + k / 2 - 1] : INT_MAX;
if (midVal1 < midVal2) {
return getMdVal(A, Astart + k / 2, B, Bstart, k - k / 2);
} else {
return getMdVal(A, Astart, B, Bstart + k / 2, k - k / 2);
}
}
};