原题
给定一个整数数组 nums,返回区间和在 [lower, upper] 之间的个数,包含 lower 和 upper。
区间和 S(i, j) 表示在 nums 中,位置从 i 到 j 的元素之和,包含 i 和 j (i ≤ j)。
参考
multiset
class Solution {
public:
    /**
     * Counts the index pairs (i, j), i <= j, whose range sum
     * nums[i] + ... + nums[j] lies in the closed interval [lower, upper].
     *
     * Writing P for the prefix sums (P[0] = 0), S(i, j) = P[j+1] - P[i].
     * For each newly computed prefix `sum`, the earlier prefixes p that
     * satisfy lower <= sum - p <= upper are exactly those in
     * [sum - upper, sum - lower]; they are looked up in an ordered multiset.
     *
     * Complexity note: multiset iterators are bidirectional, so distance()
     * walks the matched range one element at a time. Each lookup is therefore
     * linear in the number of matches and the whole routine is quadratic in
     * the worst case, not O(n log n).
     *
     * @param nums   input values
     * @param lower  inclusive lower bound for a range sum
     * @param upper  inclusive upper bound for a range sum
     * @return number of qualifying ranges; 0 when lower > upper
     */
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        // An inverted interval is empty. Without this guard the first iterator
        // below could end up after the second one, and distance() on
        // bidirectional iterators is undefined in that case.
        if (lower > upper) {
            return 0;
        }

        multiset<long long> st;  // every prefix sum seen so far
        st.insert(0);            // the empty prefix

        long long sum = 0;       // long long: prefix sums can exceed the int range
        int ans = 0;
        for (const int val : nums) {
            sum += val;
            const auto first = st.lower_bound(sum - upper);
            const auto last = st.upper_bound(sum - lower);
            ans += static_cast<int>(distance(first, last));
            st.insert(sum);
        }
        return ans;
    }
};
线段树+离散化
class Solution {
public:
    /**
     * Counts the index pairs (i, j), i <= j, whose range sum
     * nums[i] + ... + nums[j] lies in the closed interval [lower, upper].
     *
     * All prefix sums are collected, sorted and de-duplicated so that each
     * distinct value gets a 1-based rank. A segment tree indexed by rank
     * stores how many prefixes with that value have been seen so far. For
     * each new prefix `cur`, the earlier prefixes in
     * [cur - upper, cur - lower] are counted with one range query.
     *
     * @param nums   input values
     * @param lower  inclusive lower bound for a range sum
     * @param upper  inclusive upper bound for a range sum
     * @return number of qualifying ranges
     */
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        const int len = static_cast<int>(nums.size());

        // h[i] is the sum of the first i elements (h[0] = 0 is the empty
        // prefix); long long because sums of int values can exceed int.
        vector<long long> h(len + 1, 0);
        for (int i = 1; i <= len; ++i) {
            h[i] = h[i - 1] + nums[i - 1];
        }

        // Discretize: after this, h holds the distinct prefix values in
        // ascending order and n is the number of segment-tree leaves.
        sort(h.begin(), h.end());
        h.erase(unique(h.begin(), h.end()), h.end());
        const int n = static_cast<int>(h.size());

        // assign() rather than resize(): resize() keeps the existing elements,
        // so counts left over from a previous call on the same object would
        // leak into this one.
        sum.assign(4 * (n + 1), 0);

        // 1-based rank of the first distinct prefix value that is >= v.
        const auto rankOf = [&h](long long v) {
            return static_cast<int>(lower_bound(h.begin(), h.end(), v) - h.begin()) + 1;
        };

        int ans = 0;
        long long cur = 0;
        update(1, 1, n, rankOf(cur), 1);  // register the empty prefix
        for (const int val : nums) {
            cur += val;
            // Ranks [a, b] cover the stored values in [cur - upper, cur - lower];
            // a > b means no distinct prefix value falls inside that window.
            const int a = rankOf(cur - upper);
            const int b = static_cast<int>(upper_bound(h.begin(), h.end(), cur - lower) - h.begin());
            if (a <= b) {
                ans += query(1, 1, n, a, b);
            }
            // cur is itself one of the collected prefixes, so rankOf(cur) is
            // its exact rank.
            update(1, 1, n, rankOf(cur), 1);
        }
        return ans;
    }

    // Segment-tree storage: sum[rt] is the number of recorded prefixes whose
    // rank falls inside the interval owned by node rt (root is node 1).
    vector<int> sum;

    // Recomputes node rt from its two children.
    void pushup(int rt) {
        sum[rt] = sum[rt << 1] + sum[rt << 1 | 1];
    }

    // Adds val to the leaf at rank pos; node rt owns the rank interval [l, r].
    void update(int rt, int l, int r, int pos, int val) {
        if (l == r) {
            sum[rt] += val;
            return;
        }
        const int m = (l + r) >> 1;
        if (pos <= m) {
            update(rt << 1, l, m, pos, val);
        } else {
            update(rt << 1 | 1, m + 1, r, pos, val);
        }
        pushup(rt);
    }

    // Returns the total stored over ranks [a, b]; node rt owns [l, r].
    int query(int rt, int l, int r, int a, int b) {
        if (l >= a && r <= b) {
            return sum[rt];
        }
        const int m = (l + r) >> 1;
        int ans = 0;
        if (a <= m) {
            ans += query(rt << 1, l, m, a, b);
        }
        if (m < b) {
            ans += query(rt << 1 | 1, m + 1, r, a, b);
        }
        return ans;
    }
};
归并排序
常数最小的解法
class Solution {
public:
    // Running count of qualifying ranges; reset at the start of every
    // countRangeSum call and accumulated by merge().
    int ans;

    /**
     * Counts the index pairs (i, j), i <= j, whose range sum
     * nums[i] + ... + nums[j] lies in the closed interval [lower, upper].
     *
     * The prefix-sum array is merge-sorted. Before two sorted halves are
     * merged, every prefix in the left half is paired with the prefixes in
     * the right half (all of which come later in the original order) whose
     * difference falls inside [lower, upper].
     *
     * @param nums   input values
     * @param lower  inclusive lower bound for a range sum
     * @param upper  inclusive upper bound for a range sum
     * @return number of qualifying ranges; 0 when lower > upper
     */
    int countRangeSum(vector<int>& nums, int lower, int upper) {
        ans = 0;
        // An inverted interval is empty. Without this guard the two-pointer
        // count in merge() goes negative and a negative total is returned.
        if (lower > upper) {
            return 0;
        }

        const int n = static_cast<int>(nums.size());
        // sum[i] is the sum of the first i elements; long long because sums
        // of int values can exceed the int range.
        vector<long long> sum(n + 1, 0);
        vector<long long> tmp(n + 1);  // scratch buffer shared by every merge
        for (int i = 0; i < n; ++i) {
            sum[i + 1] = sum[i] + nums[i];
        }
        mergeSort(sum, tmp, 0, n + 1, lower, upper);
        return ans;
    }

    // Sorts sum[l, r) in ascending order, adding to `ans` the pairs counted by
    // merge() at every level of the recursion.
    void mergeSort(vector<long long> &sum,vector<long long> &tmp,int l,int r,int &lower,int &upper) {
        if (l + 1 >= r) {
            return;  // zero or one element: nothing to pair or sort
        }
        int m = (l + r) >> 1;
        mergeSort(sum, tmp, l, m, lower, upper);
        mergeSort(sum, tmp, m, r, lower, upper);
        merge(sum, tmp, l, m, r, lower, upper);
    }

    // Expects sum[l, m) and sum[m, r) to each be sorted ascending. First adds
    // to `ans` the number of pairs (left element, right element) whose
    // difference right - left is in [lower, upper], then merges the two
    // halves into sum[l, r) using tmp as scratch space.
    void merge(vector<long long> &sum,vector<long long> &tmp,int &l,int &m,int &r,int &lower,int &upper) {
        // Counting pass. As the left value grows, both window boundaries can
        // only move right, so the two pointers never move backwards.
        int first = m;  // first right index with sum[first] - left >= lower
        int past = m;   // first right index with sum[past]  - left >  upper
        for (int i = l; i < m; ++i) {
            while (first < r && sum[first] - sum[i] < lower) {
                ++first;
            }
            while (past < r && sum[past] - sum[i] <= upper) {
                ++past;
            }
            ans += past - first;
        }

        // Merge pass: standard two-way merge into tmp, then copy back.
        int left = l;
        int right = m;
        int out = l;
        while (left < m && right < r) {
            if (sum[left] < sum[right]) {
                tmp[out++] = sum[left++];
            } else {
                tmp[out++] = sum[right++];
            }
        }
        while (left < m) {
            tmp[out++] = sum[left++];
        }
        while (right < r) {
            tmp[out++] = sum[right++];
        }
        for (int i = l; i < r; ++i) {
            sum[i] = tmp[i];
        }
    }
};