- Range Sum Query - Mutable
中文English
Given an integer array nums, implement the following two functions:
update(i, val) Modify the element whose index is i to val.
sumRange(l, r) Return the sum of the elements whose indexes are in the range [l, r].
Example
Example 1:
Input:
nums = [1, 3, 5]
sumRange(0, 2)
update(1, 2)
sumRange(0, 2)
Output:
9
8
Example 2:
Input:
nums = [0, 9, 5, 7, 3]
sumRange(4, 4)
sumRange(2, 4)
update(4, 5)
update(1, 7)
update(0, 8)
sumRange(1, 2)
Output:
3
15
12
Notice
The array is only modifiable by the update function.
You may assume the number of calls to update and sumRange function is distributed evenly.
解法1:树状数组 (binary indexed tree / Fenwick tree)
注意:
- 涉及到C数组的add()和sum()里面的循环都必须从x开始到n或0,而不能从n或0到x。因为C数组的循环是跳着前进的,步长为lowbit(x),如果从n或0开始,不一定能跳到x。切记!
- 涉及到C数组的add()和sum()里面的x都必须+1,因为binary indexed tree里面的下标从1开始。
代码如下:
/// Solution 1: range-sum queries with point updates, backed by a Fenwick
/// (binary indexed) tree. update() and sumRange() run in O(log n);
/// construction runs in O(n). Indexes passed by callers are 0-based.
/// NOTE(review): sums are kept in int to match the int-returning interface,
/// so very large inputs can overflow.
class NumArray {
public:
    /// Builds the tree from nums.
    NumArray(vector<int> nums) {
        len = static_cast<int>(nums.size());
        A.swap(nums);  // reuse the by-value parameter instead of copying it again
        C.assign(len + 1, 0);
        // Linear-time build: C[i] covers (i - lowbit(i), i]. By the time we
        // reach i, all of its children have already been folded into it, so
        // we add its own element and then push the total to its parent.
        for (int i = 1; i <= len; ++i) {
            C[i] += A[i - 1];
            int parent = i + lowbit(i);
            if (parent <= len) {
                C[parent] += C[i];
            }
        }
    }

    /// Sets element i to val.
    void update(int i, int val) {
        add(i, val - A[i]);
        A[i] = val;
    }

    /// Returns A[i] + ... + A[j] (inclusive). sum(-1) is 0, so i == 0 works.
    int sumRange(int i, int j) {
        return sum(j) - sum(i - 1);
    }

private:
    vector<int> C, A;  // C: 1-based Fenwick array; A: current values (0-based)
    int len;

    int lowbit(int x) {
        return x & (-x);
    }

    /// Adds val to element x (0-based).
    void add(int x, int val) {
        // Walk upward starting at x + 1. Only the nodes reached by repeatedly
        // adding lowbit cover x; starting from len would skip some of them.
        for (int i = x + 1; i <= len; i += lowbit(i)) {
            C[i] += val;
        }
    }

    /// Returns the prefix sum A[0] + ... + A[x]; yields 0 for x == -1.
    int sum(int x) {
        int result = 0;
        // Walk downward starting at x + 1, stripping the lowest set bit each step.
        for (int i = x + 1; i > 0; i -= lowbit(i)) {
            result += C[i];
        }
        return result;
    }
};
/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray obj(nums);
 * obj.update(i, val);
 * int param_2 = obj.sumRange(i, j);
 */
解法2: 线段树
/// Segment-tree node covering the index range [start, end]; `sum` caches the
/// sum of the covered elements. Nodes are allocated by STree::build, and the
/// ones reachable from an STree's root are freed by that STree's destructor.
class STreeNode {
public:
    STreeNode *left, *right;
    // Initializers listed in declaration order (left, right, then the
    // private fields) so the written order matches the actual order.
    STreeNode(int start, int end, int sum)
        : left(nullptr), right(nullptr), start(start), end(end), sum(sum) {}
    friend class STree;
private:
    int start, end, sum;
};

/// Segment tree over a copy of an int array, supporting point assignment
/// (modifySTree) and inclusive range sums (querySTree) in O(log n) each.
/// It owns every node reachable from its root through raw pointers, so
/// copying is disabled to avoid double frees.
class STree {
public:
    STree(vector<int> &nums) : nums(nums) {
        // Cast before subtracting: an empty input yields end == -1 (and a
        // null root) instead of wrapping around in unsigned arithmetic.
        root = build(0, static_cast<int>(nums.size()) - 1);
    }
    ~STree() { destroy(root); }
    STree(const STree &) = delete;
    STree &operator=(const STree &) = delete;

    STreeNode *getRoot() { return root; }
    STreeNode *build(int start, int end);
    int querySTree(STreeNode *root, int start, int end);
    void modifySTree(STreeNode *root, int index, int value);
private:
    /// Frees a subtree in post-order; accepts nullptr.
    static void destroy(STreeNode *node);
    vector<int> nums;
    STreeNode *root;
};

void STree::destroy(STreeNode *node) {
    if (node == nullptr) return;
    destroy(node->left);
    destroy(node->right);
    delete node;
}

/// Builds the subtree for nums[start..end]; returns nullptr when start > end.
/// Internal nodes always get both children, since start <= mid < end.
STreeNode *STree::build(int start, int end) {
    if (start > end) return nullptr;
    STreeNode *node = new STreeNode(start, end, 0);
    if (start == end) {
        node->sum = nums[start];
        return node;
    }
    int mid = start + (end - start) / 2;
    node->left = build(start, mid);
    node->right = build(mid + 1, end);
    node->sum = node->left->sum + node->right->sum;
    return node;
}

/// Returns the sum of the elements in [start, end] that lie under `root`.
/// Returns 0 for an empty subtree or a range that does not overlap it.
int STree::querySTree(STreeNode *root, int start, int end) {
    if (root == nullptr || end < root->start || root->end < start) {
        return 0;
    }
    if (start <= root->start && root->end <= end) {
        return root->sum;
    }
    // Partial overlap only happens at internal nodes (a leaf that overlaps is
    // fully contained), so both children exist here.
    return querySTree(root->left, start, end) + querySTree(root->right, start, end);
}

/// Sets the element at `index` to `value` and refreshes the cached sums on
/// the path back up. Indexes outside `root`'s range (or a null root) are ignored.
void STree::modifySTree(STreeNode *root, int index, int value) {
    if (root == nullptr || index < root->start || root->end < index) {
        return;
    }
    if (root->start == root->end) {
        root->sum = value;
        return;
    }
    int mid = root->start + (root->end - root->start) / 2;
    if (index <= mid) {
        modifySTree(root->left, index, value);
    } else {
        modifySTree(root->right, index, value);
    }
    root->sum = root->left->sum + root->right->sum;
}
class NumArray {
public:
NumArray(vector<int> nums) {
st = new STree(nums);
}
void update(int i, int val) {
auto root = st->getRoot();
st->modifySTree(root, i, val);
}
int sumRange(int i, int j) {
auto root = st->getRoot();
return st->querySTree(root, i, j);
}
private:
STree *st;
};
/**
 * Your NumArray object will be instantiated and called as such:
 * NumArray obj(nums);
 * obj.update(i, val);
 * int param_2 = obj.sumRange(i, j);
 */