一、线段树简介
线段树是一种二叉树,与堆类似可以利用数组进行存储(下标从1开始)。线段树可以用于维护一个序列的某种属性,比如序列和、序列最大值、序列最小值等等。
二、举例说明
下面举例说明,线段树的结构以及一些操作。考虑这样一个序列:1、2、3、4、5、6、7。我们维护它的序列和。线段树结构如下图所示:
说明:
- 线段树的每个节点包含若干个数字。其中叶节点只包含一个数字
- 儿子节点的长度为父节点的一半。即:若当前序列为 [ L , R ] [L,R] [L,R],且 R − L > 0 R-L>0 R−L>0,则父节点会被对半分为: [ L , M ] [L,M] [L,M]和 [ M + 1 , R ] [M+1, R] [M+1,R]。其中 M = L + R > > 1 M=L+R>>1 M=L+R>>1
三、线段树的操作
1. 单点修改
当序列中某个数被修改时,与其相关的所有区间的和都应该被更新。什么叫“与其相关”呢?就是线段树节点中包含了索引5的节点。举个例子:第5个数被修改为了8,然后还是看图:
具体做法:
- 先判断修改点的索引在哪个范围内,然后在对应的范围内递归地去找修改点,直到找到该点(即叶节点)
- 修改该点后,回溯到根节点,并更新路径中所有节点所存储的和:左儿子节点的和 + 右儿子节点的和
2. 区间查询
比如我们要查询 [ 2 , 5 ] [2,5] [2,5]的区间和,图示如下:
可以看到,当某个节点被完全包含在查询区间内的时候,就不需要继续往下递归了,直接返回其和;否则就一直递归到叶节点。
四、亿点说明
在给出代码之前,还有一个需要说明的部分:存储线段树节点的数组大小应该为多少?
假设序列长度为
n
n
n。在线段树中,倒数第二层的节点个数一定
<
n
<n
<n。从倒数第二层到根节点的节点个数之和:
n
+
n
2
+
n
4
+
n
8
+
.
.
.
+
1
<
n
∗
1
1
−
1
2
=
2
n
n + \frac{n}{2} + \frac{n}{4} + \frac{n}{8} +... +1<n*\frac{1}{1-\frac{1}{2}}=2n
n+2n+4n+8n+...+1<n∗1−211=2n
而最后一层节点个数是倒数第二层的两倍,所以也
<
2
n
<2n
<2n。故总共的节点个数
<
4
n
<4n
<4n。
(参考yxc评论回复)
五、代码详解
1. 线段树数组定义
每个节点存储的是一个区间信息,所以应该包括:区间左端点、区间右端点以及对应的区间信息。
struct Node
{
int l, r;
int sum; // 以维护序列和为例
}tr[4 * N]; // N为序列长度
2. 建立线段树:build
建树实际上就是初始化所有节点的 l 、 r 、 s u m l、r、sum l、r、sum。对于叶节点来说, s u m sum sum即为该点的权值;对于非叶节点来说, s u m sum sum即为左右儿子节点的 s u m sum sum和。
/**
* u : 根节点
* l : 区间左端点
* r : 区间右端点
*/
void build(int u, int l, int r)
{
if (l == r) tr[u] = {l, r, w[l};
else
{
tr[u] = {l, r, 0}; // 初始化该节点
int mid = tr[u].l + tr[u].r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
}
2. 单点修改:modify
/**
* u : 根节点下标
* x : 修改点的下标
* v : 修改的值
*/
void modify(int u, int x, int v)
{
if (tr[u].l == tr[u].r) tr[u].sum = v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if ( x <= mid ) modify(u << 1, x, v); // 在左半区域,就到左儿子去修改
else modify(u << 1 | 1, x, v); // 在右半区域
// 修改完左右儿子后,更新父节点的区间信息
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
}
3. 区间查询:query
/**
* u : 根节点
* l : 区间左端点
* r : 区间右端点
*/
int query(int u, int l, int r)
{
if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
else
{
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if (l <= mid) sum += query(u << 1, l, mid);
if (r > mid) sum += query(u << 1 | 1, mid + 1, r);
return sum;
}
}