第一想法当然是用矩阵,矩阵元素Sum[i][j]表示从i到j符合条件的组合个数
对于例子-2, 5, -1
,Sum矩阵就是:
-2 3 2
5 4
-1
这样只要找到大于等于-2,小于等于2的元素就行。我以为这样做的时间复杂度会小于O(n^2),但是其实这样做的时间复杂度是O(n^2/2),其实还是O(n^2),囧
优化的方法是把要求解的问题转化为另外一种形式,然后用分治法求解,这样能把时间复杂度降低为O(nlogn)
问题转换为:
对于一个已经从小到大排好顺序的数组(这一点很重要,我在后面会解释),所有在[lower,upper]之间的组合个数 = 小于等于upper的组合个数 - 小于lower的组合个数
代码中体现为
for(int i=start, s=0; i<mid; i++, s++){
/*** wrong code: while(m<end && sums[m++]-sums[i]<lower); ***/
while(m<end && sums[m]-sums[i]<lower) m++;
while(n<end && sums[n]-sums[i]<=upper) n++;
我们希望while循环遍历到所有的符合条件的组合,但是因为while循环的条件是
sums[m]-sums[i]<lower
所以如果数组不是排好序的话,跳出while后可能还有组合情况没遍历到
为了确保sums数组是排好序的,采用归并排序的方法。
以前没深入做过归并排序,从伪代码到正常的代码,合并过程都是一句swap(),这份答案中采用了cache数组做记录,然后覆盖原数组的方法,两层循环加迭代,看的我眼花缭乱
class Solution {
public:
int countRangeSum(vector<int>& nums, int lower, int upper) {
int size=nums.size();
if(size==0) return 0;
vector<long> sums(size+1, 0);
for(int i=0; i<size; i++) sums[i+1]=sums[i]+nums[i];
return help(sums, 0, size+1, lower, upper);
}
/*** [start, end) ***/
int help(vector<long>& sums, int start, int end, int lower, int upper){
/*** only-one-element, so the count-pair=0 ***/
if(end-start<=1) return 0;
int mid=(start+end)/2;
int count=help(sums, start, mid, lower, upper)
+ help(sums, mid, end, lower, upper);
int m=mid, n=mid, t=mid, len=0;
/*** cache stores the sorted-merged-2-list ***/
/*** so we use the "len" to record the merged length ***/
vector<long> cache(end-start, 0);
for(int i=start, s=0; i<mid; i++, s++){
/*** wrong code: while(m<end && sums[m++]-sums[i]<lower); ***/
while(m<end && sums[m]-sums[i]<lower) m++;
while(n<end && sums[n]-sums[i]<=upper) n++;
count+=n-m;
/*** cache will merge-in-the-smaller-part-of-list2 ***/
//下面的cache一部分在while循环里,一部分在while循环外,两部分一起完成了缓存工作,并在之后覆盖原来的sums数组
while(t<end && sums[t]<sums[i]) cache[s++]=sums[t++];
cache[s]=sums[i];
len=s;
}
for(int i=0; i<=len; i++) sums[start+i]=cache[i];
return count;
}
};