Why Segment Tree:
面对这样一个数组,我们要频繁地做“查询区间和”(query)以及“更新某一个值”(update)的操作。对于每一次查询(包含n个元素)和更新,如果采用暴力的方法,那么显然查询时间复杂度为O(n),更新的时间复杂度为O(1);如果我们采用前缀和的方法,那么查询的时间复杂度下降为O(1),但是更新的时间复杂度却上升为O(n)。为了平衡一下这两种操作的复杂度,提出了一种新的数据结构——线段树。
What is Segment Tree:
线段树是一种数据结构,只不过把原本以数组形式存储的数据改为以树存储。如下图:
可见,从根出发,我们存储的是当前区间的区间和。然后把该区间对半分,递归下去,直到区间化为一个点,存储我们原始的数据。
当我们对树进行更新(idx,val)时,把idx与当前区间进行对比,判断我们应当进入哪个分支,递归完之后不要忘记修改当前节点的值。
当我们对树进行查询(L,R)时, 要进行判断:如果搜索区间与当前的区间没有交集,直接返回即可;如果当前区间只有一个值或者当前区间包含在搜索区间内,那么返回该节点的值;否则递归下去。
代码如下:
#include<iostream>
#define MAX_LEN 1000
using namespace std;
void build_tree(int arr[], int tree[], int node, int start, int end){
if(start == end){
//当前区间只有一个值
tree[node] = arr[start];
}else{
int mid = (start + end) / 2;
int left_node = 2 * node + 1;
int right_node = 2 * node + 2;
build_tree(arr, tree, left_node, start, mid);
build_tree(arr, tree, right_node, mid + 1, end);
tree[node] = tree[left_node] + tree[right_node];
}
}
void update_tree(int arr[], int tree[], int node, int start, int end, int idx, int val){
if(start == end){
//当前区间只有一个值
arr[idx] = val;
tree[node] = val;
}else{
int mid = (start + end) / 2;
int left_node = node * 2 + 1;
int right_node = node * 2 + 2;
if(idx >= start && idx <= mid){
update_tree(arr, tree, left_node, start, mid, idx, val);
}else{
update_tree(arr, tree, right_node, mid + 1, end, idx, val);
}
tree[node] = tree[left_node] + tree[right_node];
}
}
int query_tree(int arr[], int tree[], int node, int start, int end, int L, int R){
if(R < start || L > end){
//搜索区间与当前区间无交集
return 0;
}else if(start == end){
//当前区间只有一个值
return tree[node];
}else if(L <= start && end <= R){
//当前区间包含在搜索区间内
return tree[node];
}else{
int mid = (start + end) / 2;
int left_node = node * 2 + 1;
int right_node = node * 2 + 2;
int sum_left = query_tree(arr, tree, left_node, start, mid, L, R);
int sum_right = query_tree(arr, tree, right_node, mid + 1, end, L, R);
return sum_left + sum_right;
}
}
int main(){
int arr[] = {1, 3, 5, 7, 9, 11};
int size = 6;
int tree[MAX_LEN] = {0};
build_tree(arr, tree, 0, 0, size - 1);
for(int i = 0; i < 15; i++)
cout << tree[i] << " ";
cout << endl;
update_tree(arr, tree, 0, 0, size - 1, 4, 6);
for(int i = 0; i < 15; i++)
cout << tree[i] << " ";
cout << endl;
int s = query_tree(arr, tree, 0, 0, size - 1, 2, 5);
cout << "sum: " << s << endl;
return 0;
}