线段树是二叉树的一种,常常被用于求区间和与区间最大值等操作。
场景假设
想象一下这种场景,一个数组现在需要进行两种操作:对某区间内数字进行求和与更新数组中的某个元素。对于更新元素操作很容易看出它的时间复杂度是 O ( 1 ) O(1) O(1),区间求和的时间复杂度是 O ( n ) O(n) O(n)。
通过下面这种方法可以降低区间元素求和的时间复杂度:设置另一个数组sum_table
,这个数组每个位置的值存储table
数组中当前位置及之前的所有元素之和。这样,我们可以把区间元素求和的时间复杂度降为
O
(
1
)
O(1)
O(1),以区间[1,4]为例,计算方法为sum_table[4]-sum_table[0]
。但这时更改元素的时间复杂度又上升了,修改了下标为8的元素,那么在sum_table
中需要把下标为8及之后的元素都要进行修改。
如果我们想把这两个操作的时间复杂度平均一下就需要用到今天介绍的线段树了。简单起见,我们用个短一点的数组来进行表示。
上面这幅图,根节点是整个[0-5]范围内的区间元素之和,左孩子是前一半的数组和,右孩子是后一半的数组和,以此类推,叶子节点是每一个元素的值。这样不管是更新元素值还是求区间元素和都会是 O ( l o g n ) O(logn) O(logn)的复杂度。
为了方便后面进行代码表示,我们把上面那幅图稍微修改一下。给每个结点从上至下、从左至右开始编号,并且补齐空缺的结点使其成为一棵完全二叉树,根据编号将这些结点填入数组tree
中。
代码实现
首先看一下构建线段树的代码实现。很明显这是一个递归调用的过程,由于每次将arr
一分为二,因此递归出口就比较明显了,就是当区间的左边界与右边界相同时,也就是
s
t
a
r
t
=
=
e
n
d
start == end
start==end时,进行tree
数组填充。递归填充结点的左子树和右子树然后将左右孩子的值加起来得到当前结点的值。left_node
和right_node
的值是根据完全二叉树的性质得到的。
public void buildTree(int[] arr, int[] tree, int node, int start, int end) {
if (start == end) {
tree[node] = arr[start];
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
buildTree(arr, tree, left_node, start, mid);
buildTree(arr, tree, right_node, mid + 1, end);
tree[node] = tree[left_node] + tree[right_node];
}
}
接着是更新结点的代码,同样也是递归调用,递归出口和之前一样,还是区间边界相等时递归结束。二分查找判断需要更新哪一边的元素,找到更新结点并更新完成后逐级向上更新上层的结点。
public void updateTree(int[] arr, int[] tree, int node, int start, int end, int idx, int val) {
if (start == end) {
arr[idx] = val;
tree[node] = val;
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
if (idx >= start && idx <= mid) {
updateTree(arr, tree, left_node, start, mid, idx, val);
} else {
updateTree(arr, tree, right_node, mid + 1, end, idx, val);
}
tree[node] = tree[left_node] + tree[right_node];
}
}
下面是区间求和部分的代码,依然是递归实现的代码。我们还是从根节点出发,求出左边的和sum_left
与右边的和sum_right
,然后加起来就是区间和了。当然左边或右边也许并不在区间内,这就需要考虑递归出口了。首先我们需要考虑待求和区间与当前递归区间不重合的情况,这时返回0,结束当前递归。接着是当前区间是待求和区间的子区间的情况,直接返回当前区间的和,也就是当前结点的值。第三种情况其实可以算是第二种情况的特殊情形,也就是遇到结点是待求和区间中的元素,同样返回结点值。
public int queryTree(int[] arr, int[] tree, int node, int start, int end, int L, int R) {
if (end < L || start > R) {
return 0;
} else if (L <= start && end <= R) {
return tree[node];
} else if (start == end) {
return tree[node];
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = queryTree(arr, tree, left_node, start, mid, L, R);
int sum_right = queryTree(arr, tree, right_node, mid + 1, end, L, R);
return sum_left + sum_right;
}
}
下面是测试代码的部分。将数组arr
构建成一棵线段树,需要注意的部分是tree_size
也就是tree
长度的计算,首先需要求出线段树的高度
h
h
h为
⌈
l
o
g
2
s
i
z
e
⌉
+
1
\lceil log_2{size} \rceil+1
⌈log2size⌉+1,然后根据完全二叉树的性质得到tree_size
为
2
h
−
1
2^h-1
2h−1。
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = 6;
int tree_size = (int) (Math.pow(2, Math.ceil(Math.log(size) / Math.log(2)) + 1) - 1);
int[] tree = new int[tree_size];
SegmentTree sTree = new SegmentTree();
sTree.buildTree(arr, tree, 0, 0, size - 1);
for (int i = 0; i < tree_size; ++i) {
System.out.printf("tree[%d] = %d\n", i, tree[i]);
}
System.out.println();
sTree.updateTree(arr, tree, 0, 0, size - 1, 4, 6);
for (int i = 0; i < tree_size; ++i) {
System.out.printf("tree[%d] = %d\n", i, tree[i]);
}
int s = sTree.queryTree(arr, tree, 0, 0, size - 1, 2, 5);
System.out.println("s = " + s);
}
下面是线段树的完整Java代码实现。
public class SegmentTree {
public void buildTree(int[] arr, int[] tree, int node, int start, int end) {
if (start == end) {
tree[node] = arr[start];
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
buildTree(arr, tree, left_node, start, mid);
buildTree(arr, tree, right_node, mid + 1, end);
tree[node] = tree[left_node] + tree[right_node];
}
}
public void updateTree(int[] arr, int[] tree, int node, int start, int end, int idx, int val) {
if (start == end) {
arr[idx] = val;
tree[node] = val;
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
if (idx >= start && idx <= mid) {
updateTree(arr, tree, left_node, start, mid, idx, val);
} else {
updateTree(arr, tree, right_node, mid + 1, end, idx, val);
}
tree[node] = tree[left_node] + tree[right_node];
}
}
public int queryTree(int[] arr, int[] tree, int node, int start, int end, int L, int R) {
if (end < L || start > R) {
return 0;
} else if (L <= start && end <= R) {
return tree[node];
} else if (start == end) {
return tree[node];
} else {
int mid = start + ((end - start) >> 1);
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
int sum_left = queryTree(arr, tree, left_node, start, mid, L, R);
int sum_right = queryTree(arr, tree, right_node, mid + 1, end, L, R);
return sum_left + sum_right;
}
}
public static void main(String[] args) {
int[] arr = {1, 3, 5, 7, 9, 11};
int size = 6;
int tree_size = (int) (Math.pow(2, Math.ceil(Math.log(size) / Math.log(2)) + 1) - 1);
int[] tree = new int[tree_size];
SegmentTree sTree = new SegmentTree();
sTree.buildTree(arr, tree, 0, 0, size - 1);
for (int i = 0; i < tree_size; ++i) {
System.out.printf("tree[%d] = %d\n", i, tree[i]);
}
System.out.println();
sTree.updateTree(arr, tree, 0, 0, size - 1, 4, 6);
for (int i = 0; i < tree_size; ++i) {
System.out.printf("tree[%d] = %d\n", i, tree[i]);
}
int s = sTree.queryTree(arr, tree, 0, 0, size - 1, 2, 5);
System.out.println("s = " + s);
}
}
当然除了本文中用于求区间和的操作,线段树还可以求区间最大值,只需要将结点存放的值由子树和改成子树最大值即可。