线段树
~
有什么问题欢迎在评论区讨论~
问题引入
n个数字,m次查询。
查询方式:l,r
即求区间[l,r]的和/最大值/最小值
使用线段树
为了高效率地解决这个问题,我们需要使用线段树。
如图所示,所谓线段树就是将一个大区间分成两个小区间,小区间再分两个小小区间,直到区间中仅剩一个元素为止。
线段树有两种,一种比较勤(hen)奋(man), 另一种比较懒,代码上差别不大,所以直接就懒一点吧~~
需要的变量
int Sum[N << 2], Add[N << 2];
// Sum: 线段树数组
// Add: lazy标记(懒懒的线段树)
int A[N], n;
// A: 原数组
// n: 原数组大小
建树
在建树的时候我们用数组来模拟二叉树以节省时间空间。
对于一棵以
r
t
(
r
t
≠
0
)
rt(rt \neq0)
rt(rt=0) 为根节点的二叉树而言,其左右儿子在数组中分别为
r
t
∗
2
rt * 2
rt∗2 和
r
t
∗
2
+
1
rt * 2 + 1
rt∗2+1 。
我们用位运算来表(zhuang)示(bi),即
r
t
<
<
1
rt << 1
rt<<1 和
r
t
<
<
1
∣
1
rt << 1 | 1
rt<<1∣1 。
此外,我们把更新当前节点的操作抽取出来作为一个函数pushUp
,以减少代码重复出现。
// 更新当前节点
void pushUp(int rt) {
Sum[rt] = Sum[rt << 1] + Sum[rt << 1 | 1];
}
// 建树
void Build(int l, int r, int rt) {
if (l == r) {
Sum[rt] = A[l];
return;
}
int mid = (l + r) >> 1;
// 递归建树
Build(l, mid, rt << 1);
Build(mid + 1, r, rt << 1 | 1);
// 建树完成后需要更新当前节点(因为当前节点的子节点的值可能改变,导致当前节点的值需要改变)
pushUp(rt);
}
单点更新
现在我们需要更新一个节点的值,在该节点的值更新之后要干啥?更新他所在的小小区间和小区间和大区间的值喽~
在这里我们递归地进行更新。
/// x is the node needing to be changed.
/// [l, r] is the region now.
void udNode(int x, int val, int l, int r, int rt) {
if (l == r) {
Sum[rt] += val;
return;
}
int mid = (l + r) << 1;
// 确定x节点所在的区间
if (x <= mid) udNode(x, val, l, mid, rt << 1);
else udNode(x, val, mid + 1, r, rt << 1 | 1);
pushUp(rt);
}
区间更新
如果我们需要统一地更新一段区间的值呢?比如某一段区间的数同时增减某个数。
这时一个一个来就比较慢了,我们需要整体进行。
寻找目标区间和大区间、小区间、小小区间的交集区间,然后更新~
/// [L, R] is the wanted region.
/// [l, r] is the region now.
// 这个范围很容易混,建议画图明确一下
void udRegion(int L, int R, int val, int l, int r, int rt) {
// 在当前区间包含于目标区间时,直接更新就好~
if (L <= l && r <= R) {
Sum[rt] += val * (r - l + 1);
Add[rt] += val;
return;
}
// 否则寻找小区间和小小区间~
int mid = (l + r) >> 1;
// 下推lazy标记(下文解释~)
pushDown(rt, mid - l + 1, r - mid);
// 这里自己画个图比我啰嗦好懂多了~
if (L <= mid) udRegion(L, R, val, l, mid, rt << 1);
if (R > mid) udRegion(L, R, val, mid + 1, r, rt << 1 | 1);
pushUp(rt);
}
关于上文的pushDown
。
我们在进行区间更新后会进行查询操作,查询的范围很有可能把更新的范围包含进去,那么这个时候查询到该区间时直接返回更新后的值就好,不会涉及到小区间和小小区间的值。
所以我们就可以懒一点,暂时记录下来这里待更新而暂不进行更新,当碰到需要小区间和小小区间的时候再更新,以节省时间(很可能接下来的若干查询都不需要小区间和小小区间,更新也是挺费时间的)。
/// ln is the number of the nodes of the Left Subtree
/// rn is the number of the nodes of the Right Subtree
// 下推lazy标记
void pushDown(int rt, int ln, int rn) {
if (Add[rt]) {
Add[rt << 1] += Add[rt];
Add[rt << 1 | 1] += Add[rt];
Sum[rt << 1] += Add[rt] * ln;
Sum[rt << 1 | 1] += Add[rt] * rn;
Add[rt] = 0;
}
}
区间查询
查询区间 [ l , r ] [l, r] [l,r] 的区间和。
这里如果碰到了需要小区间/小小区间的情况,就下推lazy标记。
// 查询区间和
int Query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) return Sum[rt];
int mid = (l + r) >> 1;
pushDown(rt, mid - l + 1, r - mid);
int ret = 0;
if (L <= mid) ret += Query(L, R, l, mid, rt << 1);
if (R > mid) ret += Query(L, R, mid + 1, r, rt << 1 | 1);
return ret;
}
完整代码
#include <bits/stdc++.h>
using namespace std;
const int N = 1e5 + 10;
int Sum[N << 2], Add[N << 2];
// Sum: 线段树数组
// Add: lazy标记
int A[N], n;
// A: 原数组
// n: 原数组大小
// 更新当前节点
void pushUp(int rt) {
Sum[rt] = Sum[rt << 1] + Sum[rt << 1 | 1];
}
// 建树
void Build(int l, int r, int rt) {
if (l == r) {
Sum[rt] = A[l];
return;
}
int mid = (l + r) >> 1;
// 递归建树
Build(l, mid, rt << 1);
Build(mid + 1, r, rt << 1 | 1);
// 建树完成后需要更新当前节点(因为当前节点的子节点的值可能改变,导致当前节点的值需要改变)
pushUp(rt);
}
/// x is the node needing to be changed.
/// [l, r] is the region now.
void udNode(int x, int val, int l, int r, int rt) {
if (l == r) {
Sum[rt] += val;
return;
}
int mid = (l + r) << 1;
// 确定x节点所在的区间
if (x <= mid) udNode(x, val, l, mid, rt << 1);
else udNode(x, val, mid + 1, r, rt << 1 | 1);
pushUp(rt);
}
/// ln is the number of the nodes of the Left Subtree
/// rn is the number of the nodes of the Right Subtree
// 下推lazy标记
void pushDown(int rt, int ln, int rn) {
if (Add[rt]) {
Add[rt << 1] += Add[rt];
Add[rt << 1 | 1] += Add[rt];
Sum[rt << 1] += Add[rt] * ln;
Sum[rt << 1 | 1] += Add[rt] * rn;
Add[rt] = 0;
}
}
/// [L, R] is the wanted region.
/// [l, r] is the region now.
// 这个范围很容易混,建议画图明确一下
void udRegion(int L, int R, int val, int l, int r, int rt) {
if (L <= l && r <= R) {
Sum[rt] += val * (r - l + 1);
Add[rt] += val;
return;
}
int mid = (l + r) >> 1;
// 下推lazy标记
pushDown(rt, mid - l + 1, r - mid);
if (L <= mid) udRegion(L, R, val, l, mid, rt << 1);
if (R > mid) udRegion(L, R, val, mid + 1, r, rt << 1 | 1);
pushUp(rt);
}
// 查询区间和
int Query(int L, int R, int l, int r, int rt) {
if (L <= l && r <= R) return Sum[rt];
int mid = (l + r) >> 1;
pushDown(rt, mid - l + 1, r - mid);
int ret = 0;
if (L <= mid) ret += Query(L, R, l, mid, rt << 1);
if (R > mid) ret += Query(L, R, mid + 1, r, rt << 1 | 1);
return ret;
}
int main() {
}