最开始我的想法是使用四个个指针分别指向 num1 和 num2 的开头和结尾,然后互相追逐,指向小元素的指针去追指向大元素的指针,一共移动(len(num1)+len(num2))/2 次以后,就可以得到中位数。
当然,这个算法的复杂度是 O(m+n),不符合题意。因为要 log(m+n) 的复杂度,所以我隐约觉得需要分治策略,一般用二分查找,但是已经想不出好的二分查找的方法了(果然是 Hard 难度)。
在看了 https://www.jianshu.com/p/9bd57fd52062 以后,根据他的算法写出了如下程序:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37
| func findMedian(nums1 []int, nums2 []int, k int) float64 { if len(nums1)>len(nums2) { return findMedian(nums2,nums1,k) } if len(nums1)==0 { return float64(nums2[k-1]) } if n1,n2:=nums1[0],nums2[0];k==1 { if n1<n2 { return float64(n1) }else{ return float64(n2) } } m:=k/2 if l1:=len(nums1);m>l1 { m=l1 } n:=k-m switch { case nums1[m-1]==nums2[n-1]: return float64(nums1[m-1]) case nums1[m-1]<nums2[n-1]: return findMedian(nums1[m:],nums2,k-m) case nums1[m-1]>nums2[n-1]: return findMedian(nums1,nums2[n:],k-n) } return 0.0 } func findMedianSortedArrays(nums1 []int, nums2 []int) float64 { if l:=len(nums1)+len(nums2);l%2==0 { return (findMedian(nums1,nums2,l/2)+findMedian(nums1,nums2,l/2+1))/2.0 }else{ return findMedian(nums1,nums2,l/2+1) } }
|
这道题在 solution 中有完整的解释,其中文翻译在知乎专栏中可以找到。
大意是使用了中位数的性质:找到一个划分,使得数组中该划分两边的数字数目相等,且左边的所有数字小于右边的所有数字。那么,在两个数组中,也同样可以使用该定律,即:在两个数组中分别找到两个划分点,使得两个划分点左边的数字的数目之和等于它们右边的数字数目之和,且划分点左边的所有数字小于它们右边的所有数字(划分点左边最大的数字小于划分点右边最小的数字)。
假设数组a中的划分点为i,数组b中的划分点为j,另外假设len(a) < len(b)
,由于数组是排好序的,那么显然可以得到j = (len(a)+len(b)+1)/2-i
。另外,由上面的推理以及数组是排序的可以得出只需满足a[i-1]<=b[j]
和b[j-1]<=a[i]
即可确定i和j分别为数组a和数组b的中位数划分点。
由于题目要求使用O(log(m+n))的复杂度,那么寻找划分点的过程就需要使用二分法了。首先,确定i的范围为imin=0/imax=len(a)
,并且取i=(imin+imax)/2=len(a)/2
。如果a[i-1]>b[j]
,那就说明i有些大,这时候就可以设置imax=i-1
;如果b[j-1]>a[i]
,这就说明i有些小,这时就可以设置imin=i+1
,以此类推,直到找到符合条件的i。
最后需要讨论的是几个边界条件。首先由数组条件可以得到的边界是i>=0,i<len(a),j>=0,j<len(a)
,如果到达边界,说明某个数组的划分已经在其边缘上了,只需在另一个数组中确定位置即可,不必再进行二分迭代的查找了,所以一旦遇到边界条件,可以直接跳出循环,得到答案。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
| class Solution { public: double findMedianSortedArrays(vector<int>& nums1, vector<int>& nums2){ int m = nums1.size(); int n = nums2.size(); if (m<=n) return helper(nums1, nums2, m, n); else return helper(nums2, nums1, n, m); } double helper(vector<int>& nums1, vector<int>& nums2, int m, int n){ int i_min = 0; int i_max = m; while(i_min<=i_max){ int i = (i_min+i_max)/2; int j = (m+n+1)/2-i; if(i>0 && nums1[i-1]>nums2[j]){ i_max=i-1; }else if(i<m && nums2[j-1]>nums1[i]){ i_min=i+1; }else{ int left_max; if(i==0) left_max = nums2[j-1]; else if(j==0) left_max = nums1[i-1]; else left_max = max(nums1[i-1], nums2[j-1]); if((m+n)%2!=0) return (double)left_max; int right_min; if(i==m) right_min = nums2[j]; else if(j==n) right_min = nums1[i]; else right_min = min(nums1[i], nums2[j]); return (left_max+right_min)/2.0; } } return 0.0; } };
|