在计算机科学中,Segment Tree也称为统计树,是一种树数据结构,用于存储有关区间或段的信息。它允许查询哪些存储的段包含给定点。原则上,它是一个静态结构;也就是说,它是一种一旦建成就无法修改的结构,一个类似的数据结构是区间树。
一个集合 I 的 n 个间隔的线段树使用 O(n log n)的空间 存储,并且可以在 O(n log n) 时间内构建。线段树支持在O(log n + k) 的时间内搜索包含某个点的所有区间,k 是检索到的区间或段的数量。
要理解线段树,我们先来考虑下面的问题:
我们有一个数组 arr[0 . . . n-1]。我们应该能够
- 求从索引 l 到 r 的元素之和,其中 0 <= l <= r <= n-1
- 将数组的指定元素的值更改为新值x。我们需要做 arr[i] = x 其中 0 <= i <= n-1。
最简单的解决上述问题的方法是常规遍历,但其更新和求和的时间复杂度都是O(n),这显然不是最优的。另一种方法是Prefix Sum的方法,可以在O(1)的时间内得到一个区间的和,但更新操作的时间仍然是O(n)。
想要让更新和求和两个操作的时间复杂度都保持在O(logN),Segment Tree是一个很好的方法。
Segment Tree的表示
- 叶节点是输入数组的元素。
- 每个内部节点代表叶节点的一些合并。对于不同的问题,合并可能会有所不同。大多数情况下,合并是一个节点下的叶子节点的总和。
- 树的数组表示用于表示段树。对于索引 i 处的每个节点,左子节点位于索引(2*i+1)处,右子节点位于(2*i+2)处,父节点位于 (⌊(i – 1) / 2⌋) 处(和普通的二叉树一样)。
简单的Segment Tree
从给定数组中构建一个线段树
我们从段arr[0 开始。. . n-1]。并且每次我们将当前段一分为二(如果它还没有变成长度为 1 的段),然后在分割后的两段上执行相同的操作,对于每个这样的段,我们将其对应的元素总和存储在相应的节点中。
所构建的段树的所有级别都将被完全填充,除了最后一层。此外,这棵树将是一个完整的二叉树,因为我们总是在每一层将段分成两部分。由于构建的树始终是具有 n 个叶子的完整二叉树,因此将有n-1 个内部节点。所以节点的总数将是2*n – 1。
注意:线段树的结果本质上是一个完整的二叉树,二叉树的节点上存储着我们想要的值,因此线段树的高度就是log₂N,由于树是使用数组表示的,并且必须维护父索引和子索引之间的关系,因此为段树分配的内存大小将为(2 * 2 ⌈log 2 n⌉ – 1)。
查询给定区间范围的和
下面是获取元素总和的算法
在上面的实现中,我们需要考虑三种情况
- 如果遍历树时当前节点的范围不在给定范围内,则不会在 ans 中添加该节点的值
- 如果节点范围与给定范围部分重叠,则根据重叠向左或向右移动
- 如果范围与给定范围完全重叠,则将其添加到 ans
更新操作
和线段树构造和查询操作一样,更新也可以递归完成。我们得到了一个需要更新的索引。设diff为要添加的值。我们从线段树的根开始,将diff添加到在其范围内具有给定索引的所有节点。如果一个节点的范围不包含给定的索引,我们不会对该节点进行任何更改。
代码演示(C++)
#include <bits/stdc++.h>
using namespace std;
// A utility function to get the middle index from corner indexes.
int getMid(int s, int e) { return s + (e -s)/2; }
int getSumUtil(int *st, int ss, int se, int qs, int qe, int si)
{
// If segment of this node is a part of given range, then return
// the sum of the segment
if (qs <= ss && qe >= se)
return st[si];
// If segment of this node is outside the given range
if (se < qs || ss > qe)
return 0;
// If a part of this segment overlaps with the given range
int mid = getMid(ss, se);
return getSumUtil(st, ss, mid, qs, qe, 2*si+1) +
getSumUtil(st, mid+1, se, qs, qe, 2*si+2);
}
/* A recursive function to update the nodes which have the given */
void updateValueUtil(int *st, int ss, int se, int i, int diff, int si)
{
// Base Case: If the input index lies outside the range of
// this segment
if (i < ss || i > se)
return;
// If the input index is in range of this node, then update
// the value of the node and its children
st[si] = st[si] + diff;
if (se != ss)
{
int mid = getMid(ss, se);
updateValueUtil(st, ss, mid, i, diff, 2*si + 1);
updateValueUtil(st, mid+1, se, i, diff, 2*si + 2);
}
}
// The function to update a value in input array and segment tree.
// It uses updateValueUtil() to update the value in segment tree
void updateValue(int arr[], int *st, int n, int i, int new_val)
{
// Check for erroneous input index
if (i < 0 || i > n-1)
{
cout<<"Invalid Input";
return;
}
// Get the difference between new value and old value
int diff = new_val - arr[i];
// Update the value in array
arr[i] = new_val;
// Update the values of nodes in segment tree
updateValueUtil(st, 0, n-1, i, diff, 0);
}
// Return sum of elements in range from index qs (query start)
// to qe (query end). It mainly uses getSumUtil()
int getSum(int *st, int n, int qs, int qe)
{
// Check for erroneous input values
if (qs < 0 || qe > n-1 || qs > qe)
{
cout<<"Invalid Input";
return -1;
}
return getSumUtil(st, 0, n-1, qs, qe, 0);
}
// A recursive function that constructs Segment Tree for array[ss..se].
// si is index of current node in segment tree st
int constructSTUtil(int arr[], int ss, int se, int *st, int si)
{
// If there is one element in array, store it in current node of
// segment tree and return
if (ss == se)
{
st[si] = arr[ss];
return arr[ss];
}
// If there are more than one elements, then recur for left and
// right subtrees and store the sum of values in this node
int mid = getMid(ss, se);
st[si] = constructSTUtil(arr, ss, mid, st, si*2+1) +
constructSTUtil(arr, mid+1, se, st, si*2+2);
return st[si];
}
/* Function to construct segment tree from given array. This function
allocates memory for segment tree and calls constructSTUtil() to
fill the allocated memory */
int *constructST(int arr[], int n)
{
// Allocate memory for the segment tree
//Height of segment tree
int x = (int)(ceil(log2(n)));
//Maximum size of segment tree
int max_size = 2*(int)pow(2, x) - 1;
// Allocate memory
int *st = new int[max_size];
// Fill the allocated memory st
constructSTUtil(arr, 0, n-1, st, 0);
// Return the constructed segment tree
return st;
}
// Driver program to test above functions
int main()
{
int arr[] = {1, 3, 5, 7, 9, 11};
int n = sizeof(arr)/sizeof(arr[0]);
// Build segment tree from given array
int *st = constructST(arr, n);
// Print sum of values in array from index 1 to 3
cout<<"Sum of values in given range = "<<getSum(st, n, 1, 3)<<endl;
// Update: set arr[1] = 10 and update corresponding
// segment tree nodes
updateValue(arr, st, n, 1, 10);
// Find sum after the value is updated
cout<<"Updated sum of values in given range = "
<<getSum(st, n, 1, 3)<<endl;
return 0;
}
//This code is contributed by Snros(嗅探网)
除了上面所说的基本操作外,Segment Tree还可以做范围最小查询,相关的具体算法,我们会在以后的文章中讨论。
其他语言实现下载链接:
(包含各种语言:C语言、Python、Java、C++等均有示例)
免费资源下载:Segment Tree