题目:
Given an integer array nums
, return the number of range sums that lie in [lower, upper]
inclusive.
Range sum S(i, j)
is defined as the sum of the elements in nums
between indices i
and j
(i
? j
), inclusive.
Note:
A naive algorithm of O(n2) is trivial. You MUST do better than that.
Example:
Given nums = [-2, 5, -1]
, lower = -2
, upper = 2
,
Return 3
.
The three ranges are : [0, 0]
, [2, 2]
, [0, 2]
and their respective sums are: -2, -1, 2
.
思路:
1、二分搜索法:我们还记得累计求和的算法吧?如果我们用sum[i]表示数组中前i个元素的和,那么求任意区间[i, j]的和就可以通过sum[j + 1] - sum[i]在O(1)的时间内计算得到。在这道题目中我们就利用这一点。不过我们在这里不是把累计和放在普通数组vector中,而是放在一个有序数组(例如multiset)中(便于快速定界)。每次计算得到一个sum[i + 1]表示区间[0,i]的累计和,我们就将其插入到有序数组中,这样有序数组中就保存了从0开始到各个索引的区间和。假设我们现在遍历数组到了索引i,此时我们首先得到sum[i + 1]表示[0,i]的区间和,那么我们就需要检查一下哪些起始索引到i的区间的和满足lower <= sum[i + 1] - x <= upper,推导可得sum[i + 1] - upper <= x <= sum[i + 1] - lower。也就是需要求出0到哪些区间的累积和满足上面不等式。因此实际上在扫描第i个元素的时候,我们就可以得到从某个起始位置到i,其累积和在区间[lower, upper]的个数。当然在遍历完成之后,别忘了将当前的累计和加入累计数组中,以便于后续遍历的正确计算。对于二分查找树而言,计算上界和下界的时间复杂度都是O(nlogn),所以该算法的时间复杂度是O(nlogn),而空间复杂度则是O(n)。
2、分治法:在累积和的基础上,我们计算有多少个区间的大小落在[lower, upper]之间,一个朴素的算法就是枚举各个区间,其时间复杂度是O(n^2)。一个更好的方法是利用分治法来处理,即利用归并排序算法将数组分成左右两边,在合并左右数组之前,对于左边数组中的每一个元素,在右边数组找到一个范围,使得在这个范围中的元素与左边元素构成的区间和落在[lower, upper]之间,即在右边数组中找到两个边界,设为m,n,其中m是在右边数组中第一个使得sum[m] - sum[i] >= lower的位置,而n是第一个使得sum[n] - sum[i] > upper的位置,这样n-m就是与左边元素i所构成的位于[lower, upper]范围的区间个数。因为左右两边都是已经有序的,这样就可以避免不必要的比较(这也是为什么我们能将时间复杂度从O(n^2)降低到O(nlogn)的秘诀所在)。
代码:
1、插入排序法:
class Solution {
public:
int countRangeSum(vector<int>& nums, int lower, int upper) {
int res = 0;
long long sum = 0;
multiset<long long> sums; // this multiset is sorted automatically
sums.insert(0);
for(int i = 0; i < nums.size(); ++i) {
sum += nums[i];
res += distance(sums.lower_bound(sum - upper), sums.upper_bound(sum - lower));
sums.insert(sum);
}
return res;
}
};
2、分治法:
class Solution {
public:
int countRangeSum(vector<int>& nums, int lower, int upper) {
int len = nums.size();
vector<long> sum(len + 1, 0);
for(int i = 0; i < len; ++i) {
sum[i + 1] = sum[i] + nums[i];
}
return mergeSort(sum, lower, upper, 0, len + 1);
}
private:
int mergeSort(vector<long>& sum, int lower, int upper, int low, int high) {
if(high - low <= 1) {
return 0;
}
int mid = (low + high) / 2;
int m = mid, n = mid, count = 0;
count = mergeSort(sum, lower, upper, low, mid) + mergeSort(sum, lower, upper, mid, high);
for(int i = low; i < mid; ++i) { // for each start part in [low, mid), we get the range in [mid, high)
while(m < high && sum[m] - sum[i] < lower) ++m;
while(n < high && sum[n] - sum[i] <= upper) ++n;
count += (n - m);
}
inplace_merge(sum.begin() + low, sum.begin() + mid, sum.begin() + high);
return count;
}
};