线段树(Segment Tree)
0.
给定一数组
(1)计算区间和 – query,(2)修改数组中的某一个值 – update
方法一:遍历
时间复杂度
如果query与update的次数都很多的话,O(n)的时间复杂度会导致非常非常慢
方法二:
前缀和
给定原数组arr,用sum_arr存储arr数组前面k个元素的和
query计算区间和,只需要sum_arr[j] - sum_arr[i]即可得出,时间复杂度为O(1)
但此时update的时间复杂度会变成O(n)
对比:
所以,当query与update次数很多时,这两种方法的速度都比较慢
1.线段树
给定一数组:
其下标范围为0-5
构造线段树
根节点为0-5号数组元素的总和
其子节点将原数组分成两半,左半部分表示0-2之间所有数字的和,右半部分表示3-5之间所有数字的和
继续分,最终得到如下的一棵树
所有的叶节点为数组中所有的数字
每个中间结点保存其左右子节点的和
这样保存的好处:
假设现在需要计算下标2-5之间的数组和
将[2-5]这段区间,从根节点开始搜索,发现根节点记录的是0-5之间的数字的和,我们可以把这段区间分成两半,左边为[2]单独向左查找,右边为[3-5],右边可以直接找到一个[3-5]的中间结点,即为[3-5]之间的数字之和,其值为27。左边的[2]需要一直找到叶节点,得到其值为5。相加即可得到[2-5]之间的数字和32。这样方法的时间复杂度可以从之前的O(n)降低到O(logn)
对于update操作:
假如将4号位置的9改成6,只需要将9对应的节点(叶节点)改成6,然后依次更新其父节点,直至根节点
其时间复杂度为O(logn)
这棵树近似于完全二叉树,可以使用数组来保存整棵树的节点
将根节点所在数组中的下标记为0,依次标记其他节点。一直到11,都是一棵满二叉树。到最下面一层,节点不满,可以构造几个虚节点,使得下标依次排列下去。如下图
将各节点填入数组中
打叉的位置表示不使用
这样保存的好处:
可以从父节点很方便的找到两个子节点
因为此时有两个数组,原始数组arr,以及保存树节点的tree数组
为了使得下标不至于太混乱,我们将所有与树(tree)相关的下标都加一个node,所有原数组的下标都不带有node。诸如left_node,right_node都表示tree数组的下标,所有start,end,left,right都表示arr数组的下标
源代码:
建树
/**
* 每个节点的值为其左右子节点数值之和,所以可以使用递归
* 先计算出左右子节点的值,再相加得到此节点的值
*
* @param arr 原数组
* @param tree 树节点数组
* @param node 树节点
* @param start arr数组开始位置
* @param end arr数组结束位置
*/
void build_tree(int[] arr, int[] tree, int node, int start, int end) {
//递归出口
if (start == end) {
tree[node] = arr[start];
}
else {
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int mid = (start + end) >> 1; //区间下标的中间值
build_tree(arr, tree, left_node, start, mid); //左节点(即左子树的根节点)下标为left_node, 左节点的范围为从start到mid
build_tree(arr, tree, right_node, mid + 1, end); //右节点为mid+1 --> end
//左右子树的值相加即可得到当前节点的值
tree[node] = tree[left_node] + tree[right_node];
}
}
测试:
public class SegmentTree {
// tree数组的最大长度
private static int MAX_LEN = 1000;
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = arr.length;
int[] tree = new int[MAX_LEN];
SegmentTree segmentTree = new SegmentTree();
//node:根节点,从0号节点出发
//范围:0 --> size-1
segmentTree.build_tree(arr, tree, 0, 0, size - 1);
for (int i = 0; i < 15; i++) {
System.out.println(tree[i]);
}
}
}
运行结果
0表示此位置不使用(tree数组打叉的地方)。可以看出,这与之前画的那棵树以及tree数组是一致的。
update
void update(int[] arr, int[] tree, int node, int start, int end, int idx, int val) {
//边界条件:刚好到了需要修改的位置
if (start == end) {
arr[idx] = val;
tree[node] = val;
}
else {
int mid = (start + end) >> 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
//确定该节点落在了左分支还是右分支,只需要看idx是位于[start, mid], 还是[mid+1, end]
if (idx >= start && idx <= mid) {
//更新左分支
update(arr, tree, left_node, start, mid, idx, val);
} else {
//更新右分支
update(arr, tree, right_node, mid + 1, end, idx, val);
}
tree[node] = tree[left_node] + tree[right_node];
}
}
测试
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = arr.length;
int[] tree = new int[MAX_LEN];
SegmentTree segmentTree = new SegmentTree();
segmentTree.build_tree(arr, tree, 0, 0, size - 1);
segmentTree.update(arr, tree, 0, 0, size - 1, 4, 6);
for (int i = 0; i < 15; i++) {
System.out.println(tree[i]);
}
}
运行结果:
验证:
结果是一致的
query:计算left到right之间的数字之和
同样也是先计算左子树的和,再计算右子树的和,相加得到当前节点的值
递归出口条件:
-
如果不在范围内的话,直接return0
-
如果到了叶节点,则直接返回叶节点上的数字即可
int query(int[] arr, int[] tree, int node, int start, int end, int left, int right) {
if (right < start || left > end) {
return 0;
} else if (start == end) {
return tree[node];
} else {
int mid = (start + end) >> 1;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query(arr, tree, left_node, start, mid, left, right);
int sum_right = query(arr, tree, right_node, mid + 1, end, left, right);
return sum_left + sum_right;
}
}
测试:
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = arr.length;
int[] tree = new int[MAX_LEN];
SegmentTree segmentTree = new SegmentTree();
segmentTree.build_tree(arr, tree, 0, 0, size - 1);
segmentTree.update(arr, tree, 0, 0, size - 1, 4, 6);
int sum = segmentTree.query(arr, tree, 0, 0, size - 1, 2, 5);
for (int i = 0; i < 15; i++) {
System.out.println(tree[i]);
}
System.out.println(sum);
}
运行结果
验证:
结果是正确的
但此时,存在一些问题:
我们在query中打印一下start和end
int query(int[] arr, int[] tree, int node, int start, int end, int left, int right) {
System.out.println("start = " + start);
System.out.println("end = " + end);
System.out.println();
if (right < start || left > end) {
return 0;
} else if (start == end) {
return tree[node];
} else {
int mid = (start + end) >> 1;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query(arr, tree, left_node, start, mid, left, right);
int sum_right = query(arr, tree, right_node, mid + 1, end, left, right);
return sum_left + sum_right;
}
}
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = arr.length;
int[] tree = new int[MAX_LEN];
SegmentTree segmentTree = new SegmentTree();
segmentTree.build_tree(arr, tree, 0, 0, size - 1);
segmentTree.update(arr, tree, 0, 0, size - 1, 4, 6);
int sum = segmentTree.query(arr, tree, 0, 0, size - 1, 2, 5);
for (int i = 0; i < 15; i++) {
System.out.println(tree[i]);
}
System.out.println();
System.out.println(sum);
}
运行结果
对比上面两张图,我们可以发现,当运行到[3-5]节点时,我们已经得到了所需的右半部分的和,但此时发现,程序依旧继续运行,接着访问了[3-4], [3-3], [4-4], [5,5],做了很多无用功。
问题出现在递归出口未定义好
之前递归出口,必须递归到叶节点[3-3], [4-4], [5-5]才能结束,但是实际上在[3-5]这个节点就已经可以把所需数值拿到了,所以此时只要判断[3-5]这个范围是否在[left, right]这个范围内即可(子集)
即[start, end]位于[left, end]之间即可
完善终止条件后的代码:
int query(int[] arr, int[] tree, int node, int start, int end, int left, int right) {
System.out.println("start = " + start);
System.out.println("end = " + end);
System.out.println();
if (right < start || left > end) {
return 0;
}else if(left <= start && end <= right){
return tree[node];
} else if (start == end) {
return tree[node];
} else {
int mid = (start + end) >> 1;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query(arr, tree, left_node, start, mid, left, right);
int sum_right = query(arr, tree, right_node, mid + 1, end, left, right);
return sum_left + sum_right;
}
}
运行结果:
此时发现,运行到[3-5]后,便没有继续访问剩余节点
最终附上query的正确代码:
int query(int[] arr, int[] tree, int node, int start, int end, int left, int right) {
if (right < start || left > end) {
return 0;
}else if(left <= start && end <= right){
return tree[node];
} else if (start == end) {
return tree[node];
} else {
int mid = (start + end) >> 1;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = query(arr, tree, left_node, start, mid, left, right);
int sum_right = query(arr, tree, right_node, mid + 1, end, left, right);
return sum_left + sum_right;
}
}