高级数据结构之线段树
1.前缀和
给定一个数组arr,数组可能非常大。在程序运行过程中,你可能要做好几次query和update操作: query(arr, L, R) 表示计算数组arr中,从下标L到下标R之间的所有数字的和。 update(arr, i, val) 表示要把arr[i]中的数字改成val。 怎样尽可能快地完成一系列query和update的操作?
如果是前缀和,则只是优化了query的时间复杂度为 O(1),而update的操作仍然是O(n),因为更新了某个值,区间的presum全部要更新
那么如何降低update的时间复杂度呢?同时又让query的时间复杂度不会太高?答案就是线段树,它能让query和update的操作都变成O(logN).
2.什么是线段树?
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为O(logN)。而未优化的空间复杂度为2N,实际应用时一般还要开4N(叶节点是n,那么整个树的节点个数 2^h - 1, 2^(h - 1) = n, 防止越界的情况下,叶子节点左右自空节点也需要空间, 那么就是 2n + 2n = 4n)的数组以免越界,因此有时需要离散化让空间压缩。
对于线段树中的每一个非叶子节点[a,b],它的左儿子表示的区间为[a,(a+b)/2],右儿子表示的区间为[(a+b)/2+1,b]。因此线段树是平衡二叉树,最后的子节点数目为N,即整个线段区间的长度。
如下图所示,将整个区间不断的去分成若干个区间,每个非叶子节点代表了某个区间的和,叶子节点就是arr的对应下标的值
线段树的数组下标与值对应的关系
3.如何建立线段树?
提示:代码中的寻找区间为[start, end], 数组的范围为[l, r],与图示相反
1.建立树
递归的思想,自顶向下,每个树的父节点的值是左右子节点的和,因此可以采用递归的思想,边界就是 l == r,代表区间只有一个值就返回,区间和自底向上不断更新
public void buildTree(int[] tree, int[] data, int l, int r, int treeIndex){
if(l == r){
tree[treeIndex] = data[r];
return;
}
int mid = (l + r) / 2;
//子节点的索引
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
buildTree(tree, data, l, mid, leftIndex);
buildTree(tree, data, mid + 1, r, rightIndex);
tree[treeIndex] = tree[leftIndex] + tree[rightIndex];
}
2.修改树
得先知道要修改得index在树得哪个半区,然后同样采用递归得思想进行修改,自底向上进行不断更新
public void updateTree(int[] tree, int[] data, int l, int r, int treeIndex, int val, int index){
if(l == r){
data[l] = val;
tree[treeIndex] = val;
return;
}
int mid = (l + r) / 2;
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
//递归左右子树
if(index >= l && index <= mid){
updateTree(tree, data, l, mid, leftIndex, val, index);
}else{
updateTree(tree, data, mid + 1, r, val, index);
}
//不断往上更新
tree[treeIndex] = tree[leftIndex] + tree[rightIndex];
}
3.寻找相应的值
1.查询的区间范围[L,R]不在数组范围内,说明区间没有重合,没有重合的话,即使再继续递归下去也没有值,因此此时需要返回0
2.当l == r时,表示某个半区只有一个值,当 L <= start end <= R时,表示区间覆盖直接返回当前树根节点的值
比如求区间[2,5],那么当l == r时,则左半区找的是[2,2],右半区找的是[3,5]
3.数组范围包含于寻找区间直接返回树节点值。仍是以[2,5]为例,当右半区为[3,5]区间了,它是包含于我需要寻找的整个区间[2,5]的, 那么这个时候可以直接返回[3,5]区间的值,若不然,它会继续递归搜索左右子树,然后会有重复计算,那么树的非叶子节点就无意义了,所以必须提前减枝,返回。.
如图是没有提前减枝的情况
(代码所寻找区间为[start, end],对应图示的[L , R],数组范围[l, r]对应图示的[start, end],正好相反)
一直会向根节点递归, 直到 l == r才返回,因此可以提前减枝加上此终止条件 : if(start <= l && r <= end)return tree[treeIndex];
//start end是要寻找的数组区间, l, r是数组的长度范围
public int queryTree(int[] tree, int treeIndex, int l, int r, int start, int end){
if(l > end || r < start)return 0;//超出范围则返回0
//叶子节点
if(l == r)return tree[treeIndex];
//当前l,r落在了查询区间,则直接返回当前节点,表示区间和的一部分
//不加这一步会重复计算
if(start <= l && r <= end)return tree[treeIndex];
int mid = (l + r) / 2;
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
int leftSum = queryTree(tree, leftIndex, l, mid, start, end);
int rightSum = queryTree(tree, rightIndex, mid + 1, r, start, end);
return leftSum + rightSum;
}
4.LeetCode
307. 区域和检索 - 数组可修改
class NumArray {
private int[] data;
private int[] tree;
public NumArray(int[] nums) {
this.data = nums;
this.tree = new int[nums.length * 4];
buildTree(0, 0, nums.length - 1);
}
public void buildTree(int treeIndex, int l, int r){
if(l == r){
tree[treeIndex] = data[l];//叶子节点
return;
}
int mid = (l + r) / 2;
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
buildTree(leftIndex, l, mid);
buildTree(rightIndex,mid + 1, r);
tree[treeIndex] = tree[leftIndex] + tree[rightIndex];
}
public void update(int index, int val){
updateTree(0, 0, data.length - 1, index, val);
}
public void updateTree(int treeIndex, int l, int r, int index, int val){
if(l == r && l == index){
data[index] = val;
tree[treeIndex] = val;
return;
}
int mid = (l + r) / 2;
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
if(index >= l && index <= mid){
updateTree(leftIndex, l, mid, index, val);
}else{
updateTree(rightIndex, mid + 1, r, index, val);
}
tree[treeIndex] = tree[leftIndex] + tree[rightIndex];
}
public int sumRange(int start, int end){
return queryTree(0, 0, data.length - 1, start, end);
}
public int queryTree(int treeIndex, int l, int r, int start, int end){
if(l > end || r < start)return 0;
if(l == r)return tree[treeIndex];
if(start <= l && r <= end)return tree[treeIndex];
int mid = (l + r) / 2;
int leftIndex = 2 * treeIndex + 1;
int rightIndex = 2 * treeIndex + 2;
int leftSum = queryTree(leftIndex, l, mid, start, end);
int rightSum = queryTree(rightIndex, mid + 1, r, start, end);
return leftSum + rightSum;
}
}