数据结构——线段树
一、不带懒标记的线段树:
我们用如下的数据结构表示点为1到10的线段:
每个线段维护一个性质:比如区间最大值、区间和、区间
g
c
d
gcd
gcd等。
线段树是一个满二叉树,所以我们可以用一维数组存储整棵树。那么对于编号是 X X X的节点,它的父亲节点为 X / 2 X/2 X/2,我们也可以写成 X > > 1 X>>1 X>>1;它的左儿子为 2 X 2X 2X,也可以写成 X < < 1 X<<1 X<<1;它的右儿子为 2 X + 1 2X+1 2X+1,也可以写成 X < < 1 ∣ 1 X<<1|1 X<<1∣1。
那么假设我们要维护的点数为 N N N,我们需要开辟多大的数组来存储对应的线段树呢?
答案是 4 N 4N 4N,因为如上图所示,倒数第二层我们假设都为叶子节点,一共有 N N N个,那么由于完全二叉树的性质,我们可以知道除最后一层外,树的节点个数共为 2 N − 1 2N-1 2N−1个节点,最后一层节点个数最坏情况下为倒数第二层的 2 2 2倍,也就是 2 N 2N 2N,所以总共的节点个数大概为 4 N − 1 4N-1 4N−1,所以我们建立线段树时要开辟 4 N 4N 4N个节点,对应就是 4 N 4N 4N大小的数组来存储线段树。
接下来介绍线段树的几个操作:
build
:建立线段树。pushup
:由两个子节点的性质推出父亲节点的性质。query
:区间查询,时间复杂度: O ( 4 l o g n ) O(4logn) O(4logn)。modify
:单点修改,时间复杂度: O ( 4 l o g n ) O(4logn) O(4logn)。
build函数:
void build (int u, int l, int r) {
tr[u].l=l. tr[u].r=r;
if (l==r) return;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid+1, r);
pushup(u);
}
pushup函数:
void pushup(int u) {
当前节点的性质由两个儿子节点推出。
}
query函数:
设线段树的区间为 T l Tl Tl、 T r Tr Tr,查询的区间为 l l l、 r r r, m i d mid mid为当前线段树节点的中点:
- [ T l , T r ] [Tl,Tr] [Tl,Tr]属于 [ l , r ] [l,r] [l,r]直接返回。
- [ T l , T r ] [Tl,Tr] [Tl,Tr]与 [ l , r ] [l,r] [l,r]存在交集, l l l小于 m i d mid mid则递归左儿子, r r r大于 m i d mid mid则递归右儿子。
- [ T l , T r ] [Tl,Tr] [Tl,Tr]与 [ l , r ] [l,r] [l,r]不存在交集。(这种情况不存在)。
modify函数:
从上往下递归到目标点,再更改,再回溯且pushup
。
例题:AcWing 1275. 最大数
给定一个正整数数列。可以对这列数进行两种操作:
- 添加操作:向序列后添加一个数,序列长度变成 n + 1 n+1 n+1。
- 询问操作:询问这个序列中最后 L L L个数中最大的数是多少。
程序运行的最开始,整数序列为空,写一个程序,读入操作的序列,并输出询问操作的答案。
设有 m m m个操作,那么我们开辟 m m m个位置,每个位置初始位置为 0 0 0,用线段树维护,执行添加操作时,我们在指定位置上加一个数;执行询问操作时,我们查询区间长度的性质。单点修改,区间查询,线段树维护的性质为最大值。
#include <bits/stdc++.h>
using namespace std;
const int N = 200010;
struct node {
int l, r;
int maxn;
} tree[N << 2];
void pushup(int u) {
tree[u].maxn = max(tree[u << 1].maxn, tree[u << 1 | 1].maxn);
}
void build(int u, int l, int r) {
if (l == r) {
tree[u].l = tree[u].r = l;
tree[u].maxn = 0;
return;
}
tree[u].l = l, tree[u].r = r;
int mid = tree[u].l + tree[u].r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int p, int x) {
if (tree[u].l == p && tree[u].r == p) {
tree[u].maxn = x;
return;
}
int mid = tree[u].l + tree[u].r >> 1;
if (p <= mid) modify(u << 1, p, x);
if (p > mid) modify(u << 1 | 1, p, x);
pushup(u);
}
int query(int u, int l, int r) {
if (tree[u].l >= l && tree[u].r <= r) {
return tree[u].maxn;
}
int mid = tree[u].l + tree[u].r >> 1;
int res = 0;
if (l <= mid) res = query(u << 1, l, r);
if (r > mid) res = max(res, query(u << 1 | 1, l, r));
return res;
}
int main() {
int m, p, cnt = 0, last = 0;
cin >> m >> p;
build(1, 1, m);
while (m--) {
char op[2];
int x;
cin >> op >> x;
if (*op == 'A') {
modify(1, ++cnt, ((x + last) % p + p) % p);
} else {
last = query(1, cnt - x + 1, cnt);
cout << last << endl;
}
}
return 0;
}
例题:AcWing 245. 你能回答这些问题吗
给定长度为 N N N的数列 A A A,以及 M M M条指令,每条指令可能是以下两种之一,对于每个查询指令,输出一个整数表示答案。
-
1 x y 1\ x\ y 1 x y,查询区间 [ x , y ] [x,y] [x,y]中的最大连续子段和,即 m a x x ≤ l ≤ r ≤ y ∑ i = l r A [ i ] max_{x≤l≤r≤y}{\sum _{i=l}^rA[i]} maxx≤l≤r≤y∑i=lrA[i]。
-
2 x y 2\ x\ y 2 x y,把 A [ x ] A[x] A[x]改成 y y y。
区间修改单点查询,显然我们用线段树就可以做,维护区间最大连续子段和 t m a x tmax tmax,那么我们思考一个节点的 t m a x tmax tmax能否由两个儿子节点的 t m a x tmax tmax得到,我们发现当前节点的 t m a x tmax tmax等于左儿子的 t m a x tmax tmax或者右儿子的 t m a x tmax tmax或者左儿子的最大连续后缀+右儿子的最大连续前缀。
因此我们还需要维护最大前缀 l m a x lmax lmax和最大后缀 r m a x rmax rmax,我们思考一个节点的 l m a x lmax lmax能否由两个儿子节点的性质得到,我们发现当前节点的 l m a x lmax lmax要么等于左儿子的 l m a x lmax lmax,要么等于左儿子的所有元素和+右儿子的 l m a x lmax lmax。 r m a x rmax rmax同理可得。
因此我们还需要维护节点所有元素之和 s u m sum sum,我们思考一个节点的 s u m sum sum能否由两个儿子节点的性质得到,显然我们发现当前节点的 s u m sum sum等于左儿子的 s u m sum sum加上右儿子的 s u m sum sum。
因此对于线段树的每个节点我们需要维护 t m a x 、 l m a x 、 r m a x 、 s u m tmax、lmax、rmax、sum tmax、lmax、rmax、sum。
口胡一时爽,调题火葬场。(bushi)
#include <bits/stdc++.h>
typedef long long ll;
using namespace std;
const int maxn = 5e5 + 10;
int a[maxn];
struct Node {
int l, r;
ll tmax, lmax, rmax, sum;
} tree[maxn << 2];
void pushup(Node &u, Node &l, Node &r) {
u.sum = l.sum + r.sum;
u.lmax = max(l.lmax, l.sum + r.lmax);
u.rmax = max(r.rmax, r.sum + l.rmax);
u.tmax = max({l.tmax, r.tmax, l.rmax + r.lmax});
}
void pushup(int u) {
pushup(tree[u], tree[u << 1], tree[u << 1 | 1]);
}
void build(int u, int l, int r) {
if (l == r) {
tree[u] = {l, r, a[l], a[l], a[l], a[l]};
return;
}
tree[u].l = l, tree[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int x, int v) {
if (tree[u].l == x && tree[u].r == x) {
tree[u] = {x, x, v, v, v, v};
return;
}
int mid = tree[u].l + tree[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
if (x > mid) modify(u << 1 | 1, x, v);
pushup(u);
}
Node query(int u, int l, int r) {
if (tree[u].l >= l && tree[u].r <= r) return tree[u];
else {
int mid = tree[u].l + tree[u].r >> 1;
Node res, left, right;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else {
left = query(u << 1, l, r);
right = query(u << 1 | 1, l, r);
pushup(res, left, right);
return res;
}
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
int x, y, z;
while (m--) {
cin >> x >> y >> z;
if (x == 1) {
if (y > z) swap(y, z);
cout << query(1, y, z).tmax << endl;
} else {
modify(1, y, z);
}
}
return 0;
}
例题:AcWing 246. 区间最大公约数
给定一个长度为 N N N的数列 A A A,以及 M M M条指令,每条指令可能是以下两种之一,对于每个询问,输出一个整数表示答案。
-
C l r d C\ l\ r\ d C l r d,表示把 A [ l ] , A [ l + 1 ] , … , A [ r ] A[l],A[l+1],…,A[r] A[l],A[l+1],…,A[r]都加上 d d d。
-
Q l r Q\ l\ r Q l r,表示询问 A [ l ] , A [ l + 1 ] , … , A [ r ] A[l],A[l+1],…,A[r] A[l],A[l+1],…,A[r]的最大公约数 g c d gcd gcd。
首先我们知道一个性质: g c d ( a , b , c , d . . . ) = g c d ( a , b − a , c − b , d − c . . . ) gcd(a, b, c, d...)=gcd(a, b-a, c-b, d-c...) gcd(a,b,c,d...)=gcd(a,b−a,c−b,d−c...)。我么可以用线段树维护差分,单点修改,区间查询。
线段树维护的性质为 g c d gcd gcd和当前节点元素和 s u m sum sum。这样执行 Q Q Q操作时,我们求前 l l l项的前缀和,再把结果与区间为 [ l + 1 , r ] [l+1, r] [l+1,r]的 g c d gcd gcd取 g c d gcd gcd;执行 C C C操作时,我们单点修改,把第 l l l项加上 d d d,把第 r + 1 r+1 r+1项减去 d d d。
口胡简单,但要调多久(好耶!我不学啦!)
#include <bits/stdc++.h>
using namespace std;
constexpr int maxn = 5e5 + 10;
typedef long long ll;
ll w[maxn];
ll gcd(ll x, ll y) {
return y ? gcd(y, x % y) : x;
}
struct Node {
int l, r;
ll sum, d;
} tr[maxn << 2];
void pushup(Node &u, Node &l, Node &r) {
u.sum = l.sum + r.sum;
u.d = gcd(l.d, r.d);
}
void pushup(int u) {
pushup(tr[u], tr[u << 1], tr[u << 1 | 1]);
}
void build(int u, int l, int r) {
if (l == r) {
tr[u].l = tr[u].r = l;
tr[u].sum = w[l] - w[l - 1];
tr[u].d = w[l] - w[l - 1];
return;
}
tr[u].l = l, tr[u].r = r;
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
void modify(int u, int x, ll v) {
if (tr[u].l == x && tr[u].r == x) {
tr[u] = {x, x, tr[u].sum + v, tr[u].sum + v};
return;
}
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid) modify(u << 1, x, v);
else modify(u << 1 | 1, x, v);
pushup(u);
}
Node query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u];
else {
int mid = tr[u].l + tr[u].r >> 1;
if (r <= mid) return query(u << 1, l, r);
else if (l > mid) return query(u << 1 | 1, l, r);
else {
Node res;
auto left = query(u << 1, l, r);
auto right = query(u << 1 | 1, l, r);
pushup(res, left, right);
return res;
}
}
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> w[i];
build (1, 1, n);
while (m--) {
string s;
ll x, y, z;
cin >> s;
if (s[0] == 'Q') {
cin >> x >> y;
Node res = query(1, 1, x);
Node tmp({0, 0, 0, 0});
if (x + 1 <= y) tmp = query(1, x + 1, y);
cout << abs(gcd(res.sum, tmp.d)) << endl;
} else {
cin >> x >> y >> z;
modify(1, x, z);
if (y + 1 <= n) modify(1, y + 1, -z);
}
}
}
二、带懒标记的线段树:
pudown
:由父节点的性质来推出两个儿子节点的性质。modify
:区间修改,时间复杂度: O ( 4 l o g n ) O(4logn) O(4logn)。- 我们需要在
query
中分裂当前节点前加一个pushdown
,在modify
中分裂当前节点前加一个pushdown
;pushup
一般都会放在build
和modify
的最后。
例题:AcWing 243. 一个简单的整数问题2
给定一个长度为 N N N的数列 A A A,以及 M M M条指令,每条指令可能是以下两种之一,对于每个询问,输出一个整数表示答案。
- C l r d C\ l\ r\ d C l r d,表示把 A [ l ] , A [ l + 1 ] , … , A [ r ] A[l],A[l+1],…,A[r] A[l],A[l+1],…,A[r]都加上 d d d。
- Q l r Q\ l\ r Q l r,表示询问 数列中第 l l l到 r r r个数的和。
我们用线段树维护一个性质 s u m sum sum,并维护一个懒标记 a d d add add。 s u m sum sum代表如果考虑当前节点即子节点上的所有标记,当前区间和是多少(没有计算祖先节点上所有标记)。 a d d add add代表给当前区间的所有儿子加上 a d d add add(以当前节点为根节点的子树且不包含当前节点)。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100010;
int n, m;
int w[N];
struct Node {
int l, r;
ll sum, add;
}tr[N * 4];
void pushup(int u) {
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u) {
auto &root = tr[u], &left = tr[u << 1], &right = tr[u << 1 | 1];
if (root.add) {
left.add += root.add, left.sum += (ll)(left.r - left.l + 1) * root.add;
right.add += root.add, right.sum += (ll)(right.r - right.l + 1) * root.add;
root.add = 0;
}
}
void build(int u, int l, int r) {
if (l == r) tr[u] = {l, r, w[r], 0};
else {
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int d) {
if (tr[u]. l >= l && tr[u].r <= r) {
tr[u].sum += (ll)(tr[u].r - tr[u].l + 1) * d;
tr[u].add += d;
} else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, d);
if (r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
ll query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
ll sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main() {
ios::sync_with_stdio(false);
cin.tie(0);
cout.tie(0);
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> w[i];
build(1, 1, n);
string op;
int l, r, d;
while (m--) {
cin >> op >> l >> r;
if (op[0] == 'C') {
cin >> d;
modify(1, l, r, d);
} else cout << query(1, l, r) << endl;
}
return 0;
}
例题:AcWing 1277. 维护序列
有长为 N N N的数列,不妨设为 a 1 , a 2 , … , a N a1,a2,…,aN a1,a2,…,aN。有如下三种操作形式:
- 把数列中的一段数全部乘一个值。
- 把数列中的一段数全部加一个值。
- 询问数列中的一段数的和,由于答案可能很大,你只需输出这个数模 P P P的值。
我们用线段树维护的性质为区间和 s u m sum sum,并附加两个懒标记 m u l mul mul(乘法), a d d add add(加法)。接下来我们考虑乘法和加法优先级的问题:
假设先算加法,那么 ( v + a d d ) ∗ m u l (v+add)*mul (v+add)∗mul,如果加上一个数,则为 ( v + a d d ) ∗ m u l + a d d 2 (v+add)*mul+add2 (v+add)∗mul+add2;如果乘上一个数,则为 ( v + a d d ) ∗ m u l ∗ m u l 2 (v+add)*mul*mul2 (v+add)∗mul∗mul2,我们发现加上一个数不容易写成 ( v + a d d ) ∗ m u l (v+add)*mul (v+add)∗mul这种形式。
假设先算乘法,那么 v ∗ m u l + a d d v*mul+add v∗mul+add,如果加上一个数,则为 v ∗ m u l + a d d + a d d 2 v*mul+add+add2 v∗mul+add+add2,如果乘上一个数,则为 ( v ∗ m u l + a d d ) ∗ m u l 2 = v ∗ m u l ∗ m u l 2 + a d d ∗ m u l 2 (v*mul+add)*mul2=v*mul*mul2+add*mul2 (v∗mul+add)∗mul2=v∗mul∗mul2+add∗mul2,其都可以转换成 v ∗ m u l + a d d v*mul+add v∗mul+add的形式,因此符合要求,并且可以看出对于先乘后加来说: ( v ∗ m u l + a d d ) ∗ m u l 2 + a d d 2 = v ∗ m u l ∗ m u l 2 + a d d ∗ m u l 2 + a d d 2 (v*mul+add)*mul2+add2=v*mul*mul2+add*mul2+add2 (v∗mul+add)∗mul2+add2=v∗mul∗mul2+add∗mul2+add2, a d d = a d d ∗ m u l 2 + a d d 2 , m u l = m u l ∗ m u l 2 add=add*mul2+add2, mul=mul*mul2 add=add∗mul2+add2,mul=mul∗mul2。
我们把 m u l mul mul的初始值设为 1 1 1, a d d add add的初始值设为 0 0 0,我们把第一个操作设为乘上一个值并且加上 0 0 0,我们把第二个操作设为乘上 1 1 1并且加上一个值。
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 100010;
int n, p, m;
int w[N];
struct Node {
int l, r;
int sum, add, mul;
} tr[N * 4];
void pushup(int u) {
tr[u].sum = (tr[u << 1].sum + tr[u << 1 | 1].sum) % p;
}
void eval(Node &t, int add, int mul) {
t.sum = ((ll)t.sum * mul + (ll)(t.r - t.l + 1) * add) % p;
t.mul = (ll)t.mul * mul % p;
t.add = ((ll)t.add * mul + add) % p;
}
void pushdown(int u) {
eval(tr[u << 1], tr[u].add, tr[u].mul);
eval(tr[u << 1 | 1], tr[u].add, tr[u].mul);
tr[u].add = 0, tr[u].mul = 1;
}
void build(int u, int l, int r) {
if (l == r)
tr[u] = {l, r, w[r], 0, 1};
else {
tr[u] = {l, r, 0, 0, 1};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int add, int mul) {
if (tr[u].l >= l && tr[u].r <= r)
eval(tr[u], add, mul);
else {
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modify(u << 1, l, r, add, mul);
if (r > mid) modify(u << 1 | 1, l, r, add, mul);
pushup(u);
}
}
int query(int u, int l, int r) {
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
int sum = 0;
if (l <= mid) sum = query(u << 1, l, r);
if (r > mid) sum = (sum + query(u << 1 | 1, l, r)) % p;
return sum;
}
int main() {
scanf("%d%d", &n, &p);
for (int i = 1; i <= n; i++) scanf("%d", &w[i]);
build(1, 1, n);
scanf("%d", &m);
while (m--) {
int t, l, r, d;
scanf("%d%d%d", &t, &l, &r);
if (t == 1) {
scanf("%d", &d);
modify(1, l, r, 0, d);
} else if (t == 2) {
scanf("%d", &d);
modify(1, l, r, d, 1);
} else
printf("%d\n", query(1, l, r));
}
return 0;
}