给你一个数组 nums ,请你完成两类查询。
- 其中一类查询要求 更新 数组 nums 下标对应的值
- 另一类查询要求返回数组 nums 中索引 left 和索引 right 之间( 包含 )的 nums 元素的和 ,其中 left <= right
实现 NumArray 类:
- NumArray(int[] nums) 用整数数组 nums 初始化对象
- void update(int index, int val) 将 nums[index] 的值 更新 为 val
- int sumRange(int left, int right) 返回数组 nums 中索引 left 和索引 right 之间( 包含 )的 nums 元素的 和 (即,nums[left] + nums[left + 1], …, nums[right])
示例 1:
输入:
["NumArray", "sumRange", "update", "sumRange"]
[[[1, 3, 5]], [0, 2], [1, 2], [0, 2]]
输出:
[null, 9, null, 8]
解释:
NumArray numArray = new NumArray([1, 3, 5]);
numArray.sumRange(0, 2); // 返回 1 + 3 + 5 = 9
numArray.update(1, 2); // nums = [1,2,5]
numArray.sumRange(0, 2); // 返回 1 + 2 + 5 = 8
提示:
- 1 <= nums.length <= 3 * 104
- -100 <= nums[i] <= 100
- 0 <= index < nums.length
- -100 <= val <= 100
- 0 <= left <= right < nums.length
- 调用 update 和 sumRange 方法次数不大于 3 * 104
思路:利用线段树求解,线段树 segmentTree 是一个二叉树,每个结点保存数组 nums 在区间 [l, r] 的最小值、最大值或者总和等信息。
题目中要求返回区间元素的和,所以这里保存的是区间元素总和的信息;遇到求区间最大值、最小值等问题,就保存区间的最大值、最小值。
int n = nums.size();
从根节点开始,根节点保存了整个区间 [0, n - 1] 的总和,然后开始二分,左子节点保存 [0, (n- 1) / 2] 的总和, 右子节点保存 [(n - 1) / 2 + 1,n - 1] 的总和。一直往下二分,知道每个节点只代表一个值,此时不能再继续二分,该节点为叶子节点。
用数组保存二叉树,当某个节点的下标为 node 时, 左子节点的下标为 node * 2 + 1, 右子节点的下标为 node * 2 + 2;
- build函数:递归构造线段树;
- change函数:修改数组某个下标的值(线段树叶子节点的某个值被修改),同时要从下往上更新线段树中节点的值。
- range函数:返回某个区间的元素的和。因为线段树是按照二分的方式划分的,所以并不一定刚好对应要查询的区间;需要做一些处理
class NumArray {
vector<int> segmentTree;
int n;
void build(vector<int>& nums, int node, int left, int right) {
if (left == right) {
segmentTree[node] = nums[left];
return;
}
int mid = (left + right) / 2;
build(nums, node * 2 + 1, left , mid);
build(nums, node * 2 + 2, mid + 1, right);
segmentTree[node] = segmentTree[node * 2 + 1] + segmentTree[node * 2 + 2];
}
void change(int index, int val, int node, int left, int right) {
if (left == right) {
segmentTree[node] = val;
return;
}
int mid = (left + right) / 2;
if (index <= mid) {
change(index, val, node * 2 + 1, left, mid);
} else {
change(index, val, node * 2 + 2, mid + 1, right);
}
segmentTree[node] = segmentTree[node * 2 + 1] + segmentTree[node * 2 + 2];
}
int range(int left, int right, int node, int l, int r) {
if (left == l && right == r) {
return segmentTree[node];
}
int mid = (l + r) / 2;
if (right <= mid) {
return range(left, right, node * 2 + 1, l, mid);
} else if (left > mid){
return range(left, right, node * 2 + 2, mid + 1, r);
} else {
return range(left, mid, node * 2 + 1, l, mid) + range(mid + 1, right, node * 2 + 2, mid + 1, r);
}
}
public:
NumArray(vector<int>& nums) : n(nums.size()), segmentTree(nums.size() * 4) {
build(nums, 0, 0, n - 1);
}
void update(int index, int val) {
change(index, val, 0, 0, n - 1);
}
int sumRange(int left, int right) {
return range(left, right, 0, 0, n - 1);
}
};
n = nums.size();
- segmentTree 大小初始化为 n * 4;线段树是平衡二叉树,叶子节点的数量是 n, 二叉树的最大深度为 ⌈ log ( n ) ⌉ \lceil\log(n)\rceil ⌈log(n)⌉ + 1, 二叉树节点数量不超过 2 ⌈ log ( n ) ⌉ + 1 ) − 1 ≤ 4 n 2^{\lceil\log(n)\rceil + 1)} - 1 \leq 4n 2⌈log(n)⌉+1)−1≤4n