线段树
线段树(Segment Tree)是稍高级一点的数据结构,它一般用于维护区间信息。线段树是一棵平衡二叉树,其根结点代表着整个区间的信息,越往下的结点代表的区间越小,也就是说,线段树的每一个结点都对应着一条区间(线段)。
线段树的建立
如果有一个数组是[1,2,3,4,5,6,7,8],那么它对应的线段树大致如下:
我们从下标1开始存储每个节点(较为方便),这样每个节点 x x x的左儿子节点为 2 x 2x 2x,右孩子为 2 x + 1 2x+1 2x+1。假设 x x x结点存储的是区间 [ l e f t , r i g h t ] [left, right] [left,right]的信息, mid = ⌊ l e f t + r i g h t 2 ⌋ \operatorname{mid}=\left\lfloor\frac{left+right}{2}\right\rfloor mid=⌊2left+right⌋,那么其左右儿子存储的分别是区间 [ l e f t , m i d ] [left, mid] [left,mid]和区间 [ m i d + 1 , r i g h t ] [mid+1, right] [mid+1,right]的信息。可以发现,由于 m i d mid mid的计算左节点对应的区间长度,与右节点相同或者比之恰好多1。
例题1
我们从最简单的例题来编写一下线段树的代码:
老师想知道从某某同学当中,分数最高的是多少,现在请你编程模拟老师的询问。当然,老师有时候需要更新某位同学的成绩.
**输入描述:**
每组输入第一行是两个正整数N和M(0 < N <= 30000,0 < M < 5000),分别代表学生的数目和操作的数目。
学生ID编号从1编到N。
第二行包含N个整数,代表这N个学生的初始成绩,其中第i个数代表ID为i的学生的成绩
接下来又M行,每一行有一个字符C(只取‘Q’或‘U’),和两个正整数A,B,当C为'Q'的时候, 表示这是一条询问操作,假设A<B,他询问ID从A到B(包括A,B)的学生当中,成绩最高的是多少
当C为‘U’的时候,表示这是一条更新操作,要求把ID为A的学生的成绩更改为B。
注意:输入包括多组测试数据。
**输出描述:**
对于每一次询问操作,在一行里面输出最高成绩.
那么,线段树每个结点存储的便是当前区间的最大值,我们递归构造此线段树,代码如下:
void build(int pos, int cur_left, int cur_right) {
if (cur_left == cur_right) {
value[pos] = ::data[cur_left];
return;
}
int m = (cur_left + cur_right) >> 1;
int lchild = pos << 1, rchild = pos << 1 | 1;
build(lchild, cur_left, m);
build(rchild, m + 1, cur_right);
value[pos] = max(value[lchild], value[rchild]);
}
线段树的更新
假如我要更新 p p p结点的值,那么包含 p p p结点的所有区间的结点均需被更新,也是采用递归的形式进行更新,代码如下:
void update(int idx, int new_value, int pos, int cur_left, int cur_right) {
if (cur_left == cur_right && cur_left == idx) {
value[pos] = new_value;
return;
}
int m = (cur_left + cur_right) >> 1;
if (idx <= m) update(idx, new_value, pos << 1, cur_left, m);
if (m < idx) update(idx, new_value, pos << 1 | 1, m + 1, cur_right);
value[pos] = max(value[pos << 1], value[pos << 1 | 1]);
}
线段树的查询
那此时如果我们需要查询指定区间的信息呢?依旧是递归查询,我们先贴出代码:
int querymax(int L, int R, int pos, int cur_left, int cur_right) {
if (L <= cur_left && cur_right <= R) {
return value[pos];
}
int m = (cur_left + cur_right) >> 1;
int lans = -1, rans = -1;
if (L <= m) lans = querymax(L, R, pos << 1, cur_left, m);
if (m < R) rans = querymax(L, R, pos << 1 | 1, m + 1, cur_right);
if (lans == -1)return rans;
if (rans == -1)return lans;
return max(lans, rans);
}
查询区间的时候,有三种情况。假设需要查询的区间为 [ L , R ] [L, R] [L,R],若是目标区间覆盖了当前区间,那么当前区间的最大值是需要的,直接返回。若没有完全覆盖,且若 L L L在 m i d mid mid的左边,那么需要去左孩子处( [ c u r _ l e f t , m i d ] [cur\_left, mid] [cur_left,mid])查询目标区间需要的信息,不然该值取-1。同理,若 R R R在 m i d mid mid的右边,则需要去右孩子处( [ m i d + 1 , c u r _ r i g h t ] [mid + 1, cur\_right] [mid+1,cur_right])查询目标区间需要的信息,不然该值取-1。上述两者中,若有一者为-1,则直接返回另一者。如果都不是-1,则需要返回这二者之间的较大值。
例题1代码
例题1的整体代码如下:
#include <stdio.h>
#include <algorithm>
using namespace std;
const int MAXN = 100000;
int data[MAXN + 5];
int value[MAXN * 4 + 5];
void build(int pos, int left, int right) {
if (left == right) {
value[pos] = ::data[left];
return;
}
int m = (left + right) >> 1;
int lchild = pos << 1, rchild = pos << 1 | 1;
build(lchild, left, m);
build(rchild, m + 1, right);
value[pos] = max(value[lchild], value[rchild]);
}
int querymax(int L, int R, int pos, int left, int right) {
if (L <= left && right <= R) {
return value[pos];
}
int m = (left + right) >> 1;
int lans = -1, rans = -1;
if (L <= m) lans = querymax(L, R, pos << 1, left, m);
if (m < R) rans = querymax(L, R, pos << 1 | 1, m + 1, right);
if (lans == -1)return rans;
if (rans == -1)return lans;
return max(lans, rans);
}
void update(int idx, int new_value, int pos, int left, int right) {
if (left == right && left == idx) {
value[pos] = new_value;
return;
}
int m = (left + right) >> 1;
if (idx <= m) update(idx, new_value, pos << 1, left, m);
if (m < idx) update(idx, new_value, pos << 1 | 1, m + 1, right);
value[pos] = max(value[pos << 1], value[pos << 1 | 1]);
}
int main() {
int n, m;
while (~scanf("%d%d", &n, &m)) {
for (int i = 1; i <= n; i++) {
scanf("%d", &::data[i]);
}
build(1, 1, n);
char order;
int a, b;
for (; m--;) {
scanf(" %c%d%d", &order, &a, &b);
if (order == 'U') {
update(a, b, 1, 1, n);
}
else if (order == 'Q') {
if (a > b)swap(a, b);
printf("%d\n", querymax(a, b, 1, 1, n));
}
}
}
return 0;
}
代码中线段树申明了4倍的空间,是因为防止越界。比如区间长度为 n n n,那么线段树的最后一层便为 n n n个结点,那么该线段树的高度便为 ⌈ log 2 n ⌉ \left\lceil\log _{2} n\right\rceil ⌈log2n⌉,不难看出, ⌈ log 2 n ⌉ ≤ log 2 n + 1 \left\lceil\log _{2} n\right\rceil \leq \log _{2} n+1 ⌈log2n⌉≤log2n+1。通过等比数列求和,可得整棵树的节点数量为 1 ∗ ( 1 − 2 x ) 1 − 2 \frac{1 *\left(1-2^{x}\right)}{1-2} 1−21∗(1−2x),其中 x x x为高度,整理可得 2 log 2 n + 1 + 1 − 1 2^{\log _{2} n+1+1}-1 2log2n+1+1−1, − 1 -1 −1忽略,可得 4 n 4n 4n。
例题2
对于有些情况而言,可以添加懒标记。如下例题:
题目描述
如题,已知一个数列,你需要进行下面两种操作:
1.将某区间每一个数加上x
2.求出某区间每一个数的和
输入格式
第一行包含两个整数N、M,分别表示该数列数字的个数和操作的总个数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数加上k
操作2: 格式:2 x y 含义:输出区间[x,y]内每个数的和
输出格式
输出包含若干行整数,即为所有操作2的结果。
该题在更新的时候,需要对一整个区间进行更新,如果每次都像上面一个一个元素的更新,那么复杂度较高,所以可以引入懒标记。
懒标记
更新的时候,对于那些正好是线段树结点的区间,我们不往下递归,我们给其一个懒标记。将来要用到它的子区间的时候,再向下传递。比如在上面画的图中,我们要更新区间 [ 3 , 7 ] [3, 7] [3,7]的值,我们要涉及到的结点有:结点1 [ 1 , 8 ] [1, 8] [1,8],结点2 [ 1 , 4 ] [1, 4] [1,4],结点5 [ 3 , 4 ] [3, 4] [3,4],结点10 [ 3 , 3 ] [3, 3] [3,3],结点11 [ 4 , 4 ] [4, 4] [4,4],结点3 [ 5 , 8 ] [5, 8] [5,8],结点6 [ 5 , 6 ] [5, 6] [5,6],结点12 [ 5 , 5 ] [5, 5] [5,5],结点13 [ 6 , 6 ] [6, 6] [6,6],结点7 [ 7 , 8 ] [7, 8] [7,8],结点14 [ 7 , 7 ] [7, 7] [7,7]。其中,结点5,结点6和节点14的时候,区间就是该树结点,所以我们赋予其一个懒标记,就无需向下递归了,这样结点10,结点11,结点12,结点13便无需更新。如果以后需要用到这些节点,那么肯定会先一步访问到它们的父结点,若是它们的父结点有懒标记,再将懒标记往下更新(push_down)。
例题2代码
这里我直接粘贴一下知乎上一位大佬的代码。
using namespace std;
using ll = long long;
const int MAXN = 1e5 + 5;
ll tree[MAXN << 2], mark[MAXN << 2], n, m, A[MAXN];
void push_down(int p, int len)
{
if (len <= 1) return;
tree[p << 1] += mark[p] * (len - len / 2);
mark[p << 1] += mark[p];
tree[p << 1 | 1] += mark[p] * (len / 2);
mark[p << 1 | 1] += mark[p];
mark[p] = 0;
}
void build(int p = 1, int cl = 1, int cr = n)
{
if (cl == cr) return void(tree[p] = A[cl]);
int mid = (cl + cr) >> 1;
build(p << 1, cl, mid);
build(p << 1 | 1, mid + 1, cr);
tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
ll query(int l, int r, int p = 1, int cl = 1, int cr = n)
{
if (cl >= l && cr <= r) return tree[p];
push_down(p, cr - cl + 1);
ll mid = (cl + cr) >> 1, ans = 0;
if (mid >= l) ans += query(l, r, p << 1, cl, mid);
if (mid < r) ans += query(l, r, p << 1 | 1, mid + 1, cr);
return ans;
}
void update(int l, int r, int d, int p = 1, int cl = 1, int cr = n)
{
if (cl >= l && cr <= r) return void(tree[p] += d * (cr - cl + 1), mark[p] += d);
push_down(p, cr - cl + 1);
int mid = (cl + cr) >> 1;
if (mid >= l) update(l, r, d, p << 1, cl, mid);
if (mid < r) update(l, r, d, p << 1 | 1, mid + 1, cr);
tree[p] = tree[p << 1] + tree[p << 1 | 1];
}
int main()
{
ios::sync_with_stdio(false);
cin >> n >> m;
for (int i = 1; i <= n; ++i)
cin >> A[i];
build();
while (m--)
{
int o, l, r, d;
cin >> o >> l >> r;
if (o == 1)
cin >> d, update(l, r, d);
else
cout << query(l, r) << '\n';
}
return 0;
}
线段树的一些问题
- 假设数组元素个数为
n
n
n,那么线段树需要多少个结点?
在本文上面写了,数组元素个数为 n n n,那么最底层的结点个数便为 n n n,用等比数列求和可以得到,线段树的结点数量为 2 n − 1 2n-1 2n−1。 - 线段树是完全二叉树吗?
有些情况不是,比如区间为 [ 1 , 6 ] [1, 6] [1,6]的线段树。 - 用数组储存线段树,1为根结点,
2
n
2n
2n个树结点大小的空间够了吗?
上面也给出了证明,最好声明 4 n 4n 4n的空间。 - 线段树含有度为1的结点吗?
没有。一个结点,若是不可分,则它是叶子结点,度为0;若可分,则肯定分为两部分,度为2。 - 叶结点和树结点区间长度有怎样的关系?
某个树结点的区间长度等于该结点构成的子树的所有叶结点个数之和 - 树结点的左右子树有怎样的关系?
左子树和右子树区间长度之差总是小于等于1,即叶结点个数之差总是小于等于1 - 线段树是平衡二叉树吗?
是。如问题4,线段树的结点的度要么为0,要么为2。