背景:
在上一篇线段树入门中涉及到对线段树的更新操作是单点操作,即更新的是一个,如果现在把问题升级下,让你将数组nums的某个区间[start,end]的值都加上value,该怎么做好呢?比较容易想到的是可以复用之前的更新操作,外面给它套个for循环,依次更新,这样虽然结果正确,但时间复杂度不堪入目…很明显是不可取的,基于这个问题,从而有了lazy思想,下面开始介绍。
概述:
lazy,顾名思义,就是懒惰的意思,在本问题中所体现的则是在进行区间更新时,延迟更新,用一个lazy标记记录下就好,没必要依次更新它的子节点,等到正真用到了的时候,为保证结果的准确性,需要将lazy标记传递下去即执行下pushDown操作,这样便能大大提高了效率。
实现:
首先定义了一个TreeNode类,用于封装一些信息,如下,
private static class TreeNode {
int left, right;
int sum, lazy;
public TreeNode(int left, int right) {
this.left = left;
this.right = right;
}
public int mid() {
return left + (right - left) / 2;
}
public int length() {
return right - left + 1;
}
}
并且,在该类中还定义了取该节点区间中点的方法和取该节点区间的长度的方法,这里把它定义为静态内部类,静态内部类与成员内部类的区别是:静态内部类在实例化的时候不需要有外部类的对象来实例化,而成员内部类实例化的时候是需要的。
接着就是线段树的构造,
public void build(int index, int start, int end, int[] nums) {
tree[index] = new TreeNode(start, end);
if (start == end) {
tree[index].sum = nums[start];
return;
}
int mid = tree[index].mid();
build(2 * index + 1, start, mid, nums);
build(2 * index + 2, mid + 1, end, nums);
pushUp(index);
}
与之前的差别不大,这里的pushUp方法如下,
public void pushUp(int index) {
tree[index].sum = tree[2 * index + 1].sum + tree[2 * index + 2].sum;
}
其实就是更新和sum,
然后来看下pushDown方法,
public void pushDown(int index) {
if (tree[index].lazy != 0) {
int lazy = tree[index].lazy;
int left = 2 * index + 1;
int right = 2 * index + 2;
tree[left].lazy += lazy;
tree[right].lazy += lazy;
tree[left].sum += lazy * tree[left].length();
tree[right].sum += lazy * tree[right].length();
tree[index].lazy = 0;
}
}
该方法的语义是,将当前节点的lazy标记向子节点传,如果lazy不为0,即存在标记,就往下更新,更新完之后当前节点的lazy标记需要复原,即重新置为0。
更新操作,
public void update(int index, int start, int end, int value) {
if (tree[index].left == start && tree[index].right == end) {
tree[index].lazy += value;
tree[index].sum += value * (tree[index].length());
return; //当前节点区间吻合时,更新该节点,其子节点就不急着更新了,直接返回
}
if (tree[index].left == tree[index].right) {
return;
}
// 每次要更新之前,我先pushDown下,保证结果的正确性
pushDown(index);
int mid = tree[index].mid(), left = 2 * index + 1, right = 2 * index + 2;
if (mid >= end) {
update(left, start, end, value);
} else if (mid < start) {
update(right, start, end, value);
} else {
update(left, start, mid, value);
update(right, mid + 1, end, value);
}
pushUp(index);
}
查询方法,
public int search(int index, int start, int end) {
if (tree[index].left == start && tree[index].right == end) {
return tree[index].sum;
}
// 同样的,每次查询之前,看看当前节点的lazy是否标记过
pushDown(index);
int mid = tree[index].mid(), left = 2 * index + 1, right = 2 * index + 2;
int sum = 0;
if (mid >= end) {
sum += search(left, start, end);
} else if (mid < start) {
sum += search(right, start, end);
} else {
sum += search(left, start, mid);
sum += search(right, mid + 1, end);
}
return sum;
}
上面的更新与查询操作分支有点多,其实也可以再小小的优化下。
优化:
优化后的更新操作,
public void update1(int index, int start, int end, int value) {
if (tree[index].left >= start && tree[index].right <= end) {
tree[index].lazy += value;
tree[index].sum += value * tree[index].length();
return;
}
pushDown(index);
int mid = tree[index].mid();
if (mid >= start) {
update1(2 * index + 1, start, end, value);
}
if (mid < end) {
update1(2 * index + 2, start, end, value);
}
pushUp(index);
}
优化后的查询操作,
public int search1(int index, int start, int end) {
if (tree[index].left >= start && tree[index].right <= end) {
return tree[index].sum;
}
pushDown(index);
int sum = 0, mid = tree[index].mid();
if (mid >= start) {
sum += search1(2 * index + 1, start, end);
}
if (mid < end) {
sum += search1(2 * index + 2, start, end);
}
return sum;
}
读者可以自己思考下,这两种实现方式为什么是等价的,这样有助于更深入的掌握线段树的实现。
最后完整代码如下,
public class SegmentTree {
TreeNode[] tree;
private static class TreeNode {
int left, right;
int sum, lazy;
public TreeNode(int left, int right) {
this.left = left;
this.right = right;
}
public int mid() {
return left + (right - left) / 2;
}
public int length() {
return right - left + 1;
}
}
public SegmentTree(int[] nums) {
tree = new TreeNode[4 * nums.length];
build(0, 0, nums.length - 1, nums);
}
public void build(int index, int start, int end, int[] nums) {
tree[index] = new TreeNode(start, end);
if (start == end) {
tree[index].sum = nums[start];
return;
}
int mid = tree[index].mid();
build(2 * index + 1, start, mid, nums);
build(2 * index + 2, mid + 1, end, nums);
pushUp(index);
}
public int search(int index, int start, int end) {
if (tree[index].left == start && tree[index].right == end) {
return tree[index].sum;
}
pushDown(index);
int mid = tree[index].mid(), left = 2 * index + 1, right = 2 * index + 2;
int sum = 0;
if (mid >= end) {
sum += search(left, start, end);
} else if (mid < start) {
sum += search(right, start, end);
} else {
sum += search(left, start, mid);
sum += search(right, mid + 1, end);
}
return sum;
}
public int search1(int index, int start, int end) {
if (tree[index].left >= start && tree[index].right <= end) {
return tree[index].sum;
}
pushDown(index);
int sum = 0, mid = tree[index].mid();
if (mid >= start) {
sum += search1(2 * index + 1, start, end);
}
if (mid < end) {
sum += search1(2 * index + 2, start, end);
}
return sum;
}
public void update(int index, int start, int end, int value) {
if (tree[index].left == start && tree[index].right == end) {
tree[index].lazy += value;
tree[index].sum += value * (tree[index].length());
return;
}
if (tree[index].left == tree[index].right) {
return;
}
pushDown(index);
int mid = tree[index].mid(), left = 2 * index + 1, right = 2 * index + 2;
if (mid >= end) {
update(left, start, end, value);
} else if (mid < start) {
update(right, start, end, value);
} else {
update(left, start, mid, value);
update(right, mid + 1, end, value);
}
pushUp(index);
}
public void update1(int index, int start, int end, int value) {
if (tree[index].left >= start && tree[index].right <= end) {
tree[index].lazy += value;
tree[index].sum += value * tree[index].length();
return;
}
pushDown(index);
int mid = tree[index].mid();
if (mid >= start) {
update1(2 * index + 1, start, end, value);
}
if (mid < end) {
update1(2 * index + 2, start, end, value);
}
pushUp(index);
}
public void pushDown(int index) {
if (tree[index].lazy != 0) {
int lazy = tree[index].lazy;
int left = 2 * index + 1;
int right = 2 * index + 2;
tree[left].lazy += lazy;
tree[right].lazy += lazy;
tree[left].sum += lazy * tree[left].length();
tree[right].sum += lazy * tree[right].length();
tree[index].lazy = 0;
}
}
public void pushUp(int index) {
tree[index].sum = tree[2 * index + 1].sum + tree[2 * index + 2].sum;
}
}