Description:
Given an integer array nums, return the number of range sums that lie in [lower, upper] inclusive.
Range sum S(i, j) is defined as the sum of the elements in nums between indices i and j (i ≤ j), inclusive.
Note:
A naive $O(n^2)$ algorithm is trivial. You MUST do better than that.
Example:
Given nums = [-2, 5, -1], lower = -2, upper = 2,
return 3.
The three ranges are [0, 0], [2, 2], and [0, 2], and their respective sums are -2, -1, and 2.
Solution:
把区间和转换成两个前缀和(点)之差的做法,用线段树或者树状数组(Binary Indexed Tree / Fenwick Tree)可以解决。
基本想法是这样:
首先把所有的前缀和(前i项和)存到long[]类型的sums数组中,用long是为了防止int相加溢出。
然后对这些sums进行离散化(文中称为“哈希”,即把每个值映射为它在所有值中的排名)。因为后面要用TreeMap的二分查找,所以直接用TreeMap;当然也可以选择自己写二分查找 + HashSet去重来完成离散化。
然后实例化线段树或者树状数组。
最后对sums进行遍历。遍历到i的时候,我们希望统计[0, i-1]中满足sums[i] - sums[j]在[lower, upper]之间的j的个数(sums[i] - sums[j]正是区间[j+1, i]的和),也就是sums[j]要在[sums[i] - upper, sums[i] - lower]之间。这里就可以利用离散化之后的结果来做,将问题转换成统计排名在[map.ceilingKey(sums[i] - upper), map.floorKey(sums[i] - lower)]之间的已插入元素个数。TreeMap的上下取整查找都是二分,树状数组/线段树的查询和更新也都是O(log n),所以总复杂度是O(n log n)。
注意1:如果sums[i]本身也在[lower, upper]范围之内,也需要ans+1(这对应区间[0, i]本身,因为sums中没有表示空前缀的0)。
注意2:因为求的是排名在[map.ceilingKey(sums[i] - upper), map.floorKey(sums[i] - lower)]这个区间内的个数,所以对于树状数组来说,需要计算两个前缀和之差:sum(rank(map.floorKey(sums[i] - lower))) - sum(rank(map.ceilingKey(sums[i] - upper)) - 1)。但是对于线段树,就可以直接求解这个区间内的个数和。
Code 1: 树状数组(Binary Indexed Tree / Fenwick Tree)
import java.util.*;

/**
 * Count of Range Sum, solved with prefix sums, coordinate compression (a TreeMap from each
 * prefix-sum value to its 1-based rank) and a Fenwick tree (binary indexed tree) that counts
 * how many earlier prefix sums fall into a rank range. O(n log n) time, O(n) space.
 */
public class Solution {
    /**
     * Returns how many ranges [i, j] (i <= j) of {@code nums} have a sum in [lower, upper].
     *
     * @param nums  the input array; may be empty
     * @param lower inclusive lower bound for a range sum
     * @param upper inclusive upper bound for a range sum; when {@code lower > upper} no range
     *              qualifies and 0 is returned
     * @return the number of ranges whose sum lies in [lower, upper]
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        // Prefix sums are kept as long so that adding int elements cannot overflow.
        long[] sums = new long[n];
        TreeMap<Long, Integer> map = new TreeMap<Long, Integer>();
        for (int i = 0; i < n; i++) {
            sums[i] = (i == 0) ? nums[i] : sums[i - 1] + nums[i];
            map.put(sums[i], 0);
        }
        // Sentinels: with these present, ceilingKey/floorKey below never return null.
        map.put(Long.MAX_VALUE, 0);
        map.put(Long.MIN_VALUE, 0);
        // Replace each placeholder value with its 1-based rank in ascending key order
        // (Fenwick tree positions are 1-based).
        int rank = 1;
        for (Map.Entry<Long, Integer> entry : map.entrySet()) {
            entry.setValue(rank++);
        }
        // At most n distinct prefix sums plus the two sentinels.
        FenwickTree tree = new FenwickTree(n + 2);
        int ans = 0;
        for (int i = 0; i < n; i++) {
            long s = sums[i];
            // An earlier prefix sum p yields a valid range iff s - p is in [lower, upper],
            // i.e. p is in [s - upper, s - lower]; snap that interval to existing keys.
            long left = map.ceilingKey(s - upper);
            long right = map.floorKey(s - lower);
            // An empty snapped interval (always the case when lower > upper) contributes
            // nothing; without this check the Fenwick difference below can go negative.
            if (left <= right) {
                ans += tree.sum(map.get(right)) - tree.sum(map.get(left) - 1);
            }
            // The range [0, i] itself: there is no earlier prefix sum to subtract.
            if (lower <= s && s <= upper) {
                ans++;
            }
            tree.insert(map.get(s), 1);
        }
        return ans;
    }

    /**
     * Fenwick tree (binary indexed tree) over positions 1..n supporting point addition and
     * prefix-sum queries.
     */
    static class FenwickTree {
        int n;
        int[] c;

        FenwickTree(int n) {
            this.n = n;
            // Index 0 is unused: positions are 1-based.
            c = new int[n + 1];
        }

        /** Returns the lowest set bit of {@code x}. */
        public int lowbit(int x) {
            return x & -x;
        }

        /** Adds {@code dif} at position {@code x}; {@code x} must be in 1..n. */
        public void insert(int x, int dif) {
            for (; x <= n; x += lowbit(x)) {
                c[x] += dif;
            }
        }

        /** Returns the sum of positions 1..x, or 0 when {@code x <= 0}. */
        public int sum(int x) {
            int total = 0;
            for (; x > 0; x -= lowbit(x)) {
                total += c[x];
            }
            return total;
        }
    }

    public static void main(String[] args) {
        Solution s = new Solution();
        int[] arr = { -2, 5, -1 };
        System.out.println(s.countRangeSum(arr, -2, 2));
    }
}
import java.util.*;
/**
 * Count of Range Sum, solved with prefix sums, coordinate compression (a TreeMap from each
 * prefix-sum value to its 1-based rank) and a segment tree that counts how many earlier
 * prefix sums fall into a rank range. O(n log n) time, O(n) space.
 */
public class Solution {
    /**
     * Returns how many ranges [i, j] (i <= j) of {@code nums} have a sum in [lower, upper].
     *
     * @param nums  the input array; may be empty
     * @param lower inclusive lower bound for a range sum
     * @param upper inclusive upper bound for a range sum; when {@code lower > upper} no range
     *              qualifies and 0 is returned
     * @return the number of ranges whose sum lies in [lower, upper]
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        int n = nums.length;
        // Prefix sums are kept as long so that adding int elements cannot overflow.
        long[] sums = new long[n];
        TreeMap<Long, Integer> map = new TreeMap<Long, Integer>();
        for (int i = 0; i < n; i++) {
            sums[i] = (i == 0) ? nums[i] : sums[i - 1] + nums[i];
            map.put(sums[i], 0);
        }
        // Sentinels: with these present, ceilingKey/floorKey below never return null.
        map.put(Long.MAX_VALUE, 0);
        map.put(Long.MIN_VALUE, 0);
        // Replace each placeholder value with its 1-based rank in ascending key order.
        int rank = 1;
        for (Map.Entry<Long, Integer> entry : map.entrySet()) {
            entry.setValue(rank++);
        }
        // At most n distinct prefix sums plus the two sentinels.
        SegmentTree st = new SegmentTree(n + 2);
        int ans = 0;
        for (int i = 0; i < n; i++) {
            long s = sums[i];
            // An earlier prefix sum p yields a valid range iff s - p is in [lower, upper],
            // i.e. p is in [s - upper, s - lower]; snap that interval to existing keys.
            long left = map.ceilingKey(s - upper);
            long right = map.floorKey(s - lower);
            // Skip an empty snapped interval; query() requires left <= right.
            if (left <= right) {
                ans += st.query(map.get(left), map.get(right), 1);
            }
            // The range [0, i] itself: there is no earlier prefix sum to subtract.
            if (lower <= s && s <= upper) {
                ans++;
            }
            st.update(map.get(s), 1, 1);
        }
        return ans;
    }

    /**
     * Segment tree over positions 1..n holding a count per position, with point increments
     * and range-sum queries. All counts start at 0.
     */
    static class SegmentTree {
        int n;
        Node[] nodes;

        SegmentTree(int n) {
            this.n = n;
            // Node idx has children 2*idx and 2*idx+1; 5n slots exceed the < 4n required.
            nodes = new Node[n * 5];
            build(1, n, 1);
        }

        /**
         * Creates node {@code idx} covering [left, right] and, recursively, its subtree.
         *
         * @return the total count stored in the subtree (0 right after construction)
         */
        public int build(int left, int right, int idx) {
            Node node = new Node();
            node.left = left;
            node.right = right;
            nodes[idx] = node;
            if (left == right) {
                return node.count = 0;
            }
            int mid = node.calmid();
            return node.count = build(left, mid, idx << 1) + build(mid + 1, right, (idx << 1) | 1);
        }

        /** Adds {@code x} at position {@code key}, updating every node on the path from {@code idx}. */
        public void update(int key, int x, int idx) {
            Node node = nodes[idx];
            node.count += x;
            if (node.left == node.right) {
                return;
            }
            int mid = node.calmid();
            update(key, x, key <= mid ? idx << 1 : (idx << 1) | 1);
        }

        /**
         * Returns the total count over [left, right]; the interval must be non-empty and lie
         * within node {@code idx}'s interval.
         */
        public int query(int left, int right, int idx) {
            Node node = nodes[idx];
            if (left == node.left && right == node.right) {
                return node.count;
            }
            int mid = node.calmid();
            if (right <= mid) {
                return query(left, right, idx << 1);
            }
            if (left > mid) {
                return query(left, right, (idx << 1) | 1);
            }
            return query(left, mid, idx << 1) + query(mid + 1, right, (idx << 1) | 1);
        }

        /** A tree node: its closed interval [left, right] and the total count stored in it. */
        static class Node {
            int left;
            int right;
            int count;

            /** Midpoint of [left, right]; unsigned shift avoids overflow of left + right. */
            int calmid() {
                return (left + right) >>> 1;
            }
        }
    }

    public static void main(String[] args) {
        Solution s = new Solution();
        int[] arr = { 0, -1, -2, -3, 0, 2 };
        System.out.println(s.countRangeSum(arr, 3, 5));
    }
}