前言
线段树是算法竞赛中一个比较常用的数据结构,单独考查线段树可能并不会很难,但是结合其他算法就会变得比较困难。现在不必担心,根据我自己的经历,学好线段树只是时间问题,自己手动多敲过几周就能熟练运用了。
正片
线段树入门
线段树介绍
线段树是一颗二叉树,除了叶节点之外,每个节点有两个子节点。顾名思义,线段树里面应该有线段,那肯定有人会问这里的线段是指什么。在此处,线段是指一个数组中某个区间
[
l
,
r
]
[l, r]
[l,r],你可以把整个数组想象成一条线,那么其中的一段
[
l
,
r
]
[l,r]
[l,r]就是线段了。而线段树中,每个节点都会维护某个区间的信息,可以是区间之和、区间最大/小值等等。因此,线段树多用于解决区间问题。
讲了这么多,一定有人还啥都没懂。下面一张图,是将长度为7的线段建成一棵线段树:
可以看到,我们每次都将节点
[
l
,
r
]
[l, r]
[l,r]拆成两个子节点
[
l
,
m
i
d
]
[l, mid]
[l,mid]和
[
m
i
d
+
1
,
r
]
[mid + 1, r]
[mid+1,r],直到最后不能拆分为止,
即当前区间变成一个点,如
[
1
,
1
]
,
[
2
,
2
]
[1, 1],[2,2]
[1,1],[2,2]。
可是搞这么复杂又啥好处呢?用数组不香吗?搞得花里胡哨!
刚刚说了,线段树多用于解决区间问题,考虑下面问题。
单点修改,区间查询问题
给你 N N N个数 a i a_i ai, M M M次操作,现在有两种操作,①是修改 a p a_p ap的值为 v v v,②是查询区间 [ l , r ] [l, r] [l,r]的和。
暴力做法
如果我们用数组来写,
对于单点修改,直接将
a
p
a_p
ap改成
v
v
v即可,复杂度
O
(
1
)
O(1)
O(1),那么查询区间和呢?
对于区间查询,可以遍历一遍区间
[
l
,
r
]
[l,r]
[l,r]得到答案,复杂度
O
(
r
−
l
+
1
)
O(r-l+1)
O(r−l+1),如果是询问[1,N]就会变成
O
(
N
)
O(N)
O(N)。
那前缀和做差呢?这样只需要
O
(
1
)
O(1)
O(1)。但是前缀和得先维护好前缀和数组,当你修改了一个点
p
p
p的值的时候,你得将
[
p
,
n
]
[p,n]
[p,n]的前缀和数组全部更新一遍,这样单点修改就炸了。
线段树做法
于是这个时候,就需要用到我们的线段树了。
对于修改,直接在线段树中跑到对应的叶节点
[
p
,
p
]
[p, p]
[p,p],修改节点值即可,复杂度为
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)。
对于查询,在线段树里面找到各个子区间,相加就能得到答案。以上图为例,如果你要找
[
1
,
5
]
[1,5]
[1,5]的和,那么答案就是点2和点12的和,也就是
s
u
m
[
1
,
4
]
+
s
u
m
[
5
,
5
]
sum[1, 4] + sum[5, 5]
sum[1,4]+sum[5,5],复杂度也为
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)。
虽然修改的复杂度由
O
(
1
)
O(1)
O(1)->
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)反而增加,但查询的复杂度由
O
(
N
)
O(N)
O(N)->
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)下降了很多。
有得有失,不过综合来看,这样会好得多。
时间复杂度分析
当我们建立一棵线段树的时候,有
l
o
g
2
N
log_2N
log2N层
所以单点修改的时候,到达叶节点需要经过
l
o
g
2
N
log_2N
log2N层,复杂度就是
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)。
对于区间查询,每一个区间都能分成好几个小区间,对于每个小区间,我们查询至多
l
o
g
2
N
log_2N
log2N层,所以复杂度就是小区间个数
∗
l
o
g
2
N
*log_2N
∗log2N,由于小区间不多,忽略掉常数就是
O
(
l
o
g
2
N
)
O(log_2N)
O(log2N)。
实现
首先考虑每个节点记录什么。
(1)是代表哪个区间
(2)记录区间之和
这里我用结构体来表示节点(当然可以用数组,看你习惯)
struct SegmentTree {
int l, r, sum;
} t[MAX << 2];//MAX << 2 == MAX * 4, 位运算
那么这里为什么空间要开四倍呢?
我一开始以为是两倍,但后来发现错了…
我的想法:
一颗线段树有
l
o
g
2
N
log_2N
log2N层,第一层有1个,第二层有2个,第三层有4个…最后一层有
2
l
o
g
2
N
2^{log_2N}
2log2N个,所以总共有
∑
i
=
0
l
o
g
2
N
2
i
=
2
l
o
g
2
N
+
1
−
1
=
2
l
o
g
2
N
∗
2
−
1
=
2
N
−
1
\displaystyle\sum_{i=0}^{log_2N}2^i=2^{log_2N+1}-1=2^{log_2N}*2-1=2N-1
i=0∑log2N2i=2log2N+1−1=2log2N∗2−1=2N−1 个节点
这样来看节点个数确实是
2
N
−
1
2N-1
2N−1个
这里我们是以 当前节点在数组中下标为
u
u
u,那么左儿子为
2
u
2u
2u,右儿子为
2
u
+
1
2u+1
2u+1的方法来建树的,所以可能会有点乘2之后超过
2
N
2N
2N,会造成数组越界,所以空间开到4倍,但是实际上只用了2倍的空间。
可以参考这个线段树为什么要开4倍空间
接下来看操作
在这里只需要三个操作:
①建树
②单点修改
③区间查询
这里我根据我自己的个人习惯,我加了几个define,因为写代码好写
加了之后你的左子树就是
t
[
l
c
]
t[lc]
t[lc],右子树就是
t
[
r
c
]
t[rc]
t[rc],比起原来的
t
[
u
<
<
1
]
,
t
[
u
<
<
1
∣
1
]
t[u<<1], t[u<<1|1]
t[u<<1],t[u<<1∣1]要简洁得多
至于为啥是左子树的下标是
2
u
2u
2u,右子树的下标是
2
u
+
1
2u+1
2u+1,因为线段树也是二叉树,所以可以用这种方式来直接建树,并且不会出现冲突。可以自己手动建一颗试一试。
//lc为左子树,lc -> leftChild
//rc为右子树, rc -> rightChild
//m为(l+r)/2, m -> mid
#define lc u<<1
#define rc u<<1|1
#define m (l+r)/2
#define mid (t[u].l+t[u].r)/2
建树
void build(int u, int l, int r) {//u为当前节点,当前区间为[l, r]
t[u].l = l, t[u].r = r;
if (l == r) {//到达叶节点
t[u].sum = a[l];//因为当前的区间端点 l 表示的是数组中的位置
//所以当前节点的值就是a[l]
return;
}
build(lc, l, m);//建立左子树
build(rc, m + 1, r);//建立右子树
t[u].sum = t[lc].sum + t[rc].sum;//更新当前点的和
//当前区间[l, r]的和可以由他的子区间[l, m]和[m + 1, r]得到
}
单点修改
void update(int u, int p, int v) {
if (t[u].l == t[u].r) {//当前点为叶节点
t[u].sum = v;
return;
}
int mid = (t[u].l + t[u].r) / 2;//左子树[l, mid], 右子树[mid + 1, r]
if (p <= mid) update(lc, p, v);//如果当前点在左子树中,那就进左子树
else update(rc, p, v);//当前点在右子树中国
t[u].sum = t[lc].sum + t[rc].sum;//更新完值之后要记得更新区间和
}
区间查询
int query(int u, int ql, int qr) {
if (ql <= t[u].l && t[u].r <= qr) //当前区间完全被查询的区间所包含
return t[u].sum;//直接返回当前区间和
int mid = (t[u].l + t[u].r) / 2;
int res = 0;//记录查询答案
if (ql <= mid) res += query(lc, ql, qr);
//左子树[l, mid]交[ql, qr]非空 -> 有交集,条件就是ql <= mid
if (qr > mid) res += query(rc, ql, qr);
//右子树[mid + 1, r]交[ql, qr]非空 -> 有交集,条件就是qr >= mid + 1, 也就是qr > mid
//因为也没有更新节点的值,所以不用这句:t[u].sum = t[lc].sum + t[rc].sum
return res;//最后返回答案
}
搞清楚这个简单的问题后,我们可以再进一步,将单点修改变成区间修改,学习带标记下传的线段树。
区间修改,区间查询问题(点我进入此题):
给你 N N N个数 a i a_i ai, M M M此操作,现在有两种操作,①是区间 [ l , r ] [l, r] [l,r]所有的值都加上 k k k,②是查询区间 [ l , r ] [l, r] [l,r]的和。
分析做法
首先,如果你和之前一样单点修改,那么你需要修改
r
−
l
+
1
r-l+1
r−l+1个点,复杂度会到达
N
l
o
g
2
N
Nlog_2N
Nlog2N,显然这不是我们想要的,所以我们需要优化。
那么能不能类似之前区间查询一样进行修改呢?当然可以。
显然,如果我们要修改某一区间 [ l , r ] [l,r] [l,r]的值,我们只需要修改到它的小区间为止,而不用修改到每一个叶节点,否则就炸了。但是这里会有问题,以修改区间 [ 1 , 2 ] [1,2] [1,2]为例。
我们的想法是修改到 [ 1 , 2 ] [1,2] [1,2],这样修改就比较快,但是如果你下一次要查询 [ 1 , 1 ] [1,1] [1,1]咋办,上次你只修改到 [ 1 , 2 ] [1,2] [1,2],但是 [ 1 , 1 ] [1,1] [1,1]节点的值还是原来的,所以这里我们引入一个新的东西——延时标记 t a g tag tag。
延时标记有啥用呢?比如当你修改到 [ 1 , 2 ] [1,2] [1,2],你可以给这个节点上一个标记,就记为要修改的值,当你下次查询 [ 1 , 1 ] [1,1] [1,1]时,由于你必然会经过 [ 1 , 2 ] [1,2] [1,2]节点,因此在这个时候,你可以将标记下传(push_down),即之前更新的值往下传递,这样 [ 1 , 1 ] [1,1] [1,1]和 [ 2 , 2 ] [2,2] [2,2]就都修改完,这个时候你在查,就不会出问题,这样我们就成功将复杂度降为 O ( l o g 2 N ) O(log_2N) O(log2N),此处复杂度和之前区间查询一样。
说白了,延时标记就是,你不用查的时候我就不下传,只有要往下查的时候,我才会更新,用就传递,不用就放那里。这样大大减少了我们的修改次数。
实现
这里与之前相比没有什么太大,都写在代码当中。
#include<bits/stdc++.h>
using namespace std;
#define m (l+r)/2
#define mid (t[u].l+t[u].r)/2
#define lc u<<1
#define rc u<<1|1
typedef long long ll;
const int MAX = 1e5 + 10;
int N, M;
ll a[MAX];
struct SegmentTree {
int l, r;
ll sum, tag;
void upd(ll v) {
sum += 1ll * (r - l + 1) * v;//当前区间之和加上 区间元素个数 * v
tag += v;//标记该点
}
} t[MAX << 2];
void push_up(int u) {//推荐这么写, 因为有时候需要push_up的东西很多, 写成函数简洁
t[u].sum = t[lc].sum + t[rc].sum;
}
void push_down(int u) {
if (t[u].tag) {//如果有标记, 就标记下传
t[lc].upd(t[u].tag);//传给左子树
t[rc].upd(t[u].tag);//传给右子树
t[u].tag = 0;//标记清空
}
}
void build(int u, int l, int r) {
t[u].l = l, t[u].r = r, t[u].tag = 0;
if (l == r) {
t[u].sum = a[l];
return;
}
build(lc, l, m); build(rc, m + 1, r);
push_up(u);
}
void update(int u, int ql, int qr, ll v) {
if (ql <= t[u].l && t[u].r <= qr) {
t[u].upd(v);
return;
}
//到达此处, 说明没有被完全包含, 需要访问左子树或者右子树
push_down(u);//所以如果有标记要将标记下传
if (ql <= mid) update(lc, ql, qr, v);
if (qr > mid) update(rc, ql, qr, v);
push_up(u);//更新当前点
}
ll query(int u, int ql, int qr) {
if (ql <= t[u].l && t[u].r <= qr) return t[u].sum;
push_down(u);//与update同理
ll res = 0;
if (ql <= mid) res += query(lc, ql, qr);
if (qr > mid) res += query(rc, ql, qr);
//为什么这里不用push_up呢, 因为查询过程中只会对下面的点更新
//而上面的节点一定比下面的点更新的早,所以这里早就更新完了
//如果你整不明白那还是加上一句push_up(u)
//push_up(u);
return res;
}
int main() {
ios::sync_with_stdio(0); cin.tie(0); cout.tie(0);
cin >> N >> M;
for (int i = 1; i <= N; i++) cin >> a[i];
build(1, 1, N);
while (M--) {
int op, ql, qr;
cin >> op >> ql >> qr;
if (op == 1) {
ll k; cin >> k;
update(1, ql, qr, k);
}
else cout << query(1, ql, qr) << endl;
}
return 0;
}
如果你已经学会了,那么就练几个题来巩固吧!
练习题
P3870 [TJOI2009]开关
给你一串只有01的串,可以反转区间
[
l
,
r
]
[l, r]
[l,r],查询区间内有多少个1。
P3373【模板】线段树 2
此题为区间乘法+加法线段树,需要两个标记,一个乘法,一个加法。
做的过程中注意标记下传时,乘法加法顺序。
P1471 方差
化简一下方差式子,就会发现只需要多维护一个平方和。
P4145 上帝造题的七分钟2 / 花神游历各国
这题比较特殊,区间开根号向下取整,只要区间最大值为1就不需要在更新,维护最大值,每次更新区间再来一个判断。