文章目录
线段树
高级数据结构。
关于线段树支持的操作:
1.单点修改
2.区间修改(懒标记)
3.区间查询
一.关于单点修改,区间查询的线段树(不需要懒标记)
例题:
题目描述:
给定 n 个数组成的一个数列,规定有两种操作,一是修改某个元素,二是求子数列 [a,b] 的连续和。
输入格式
第一行包含两个整数 n 和 m,分别表示数的个数和操作次数。
第二行包含 n 个整数,表示完整数列。
接下来 m 行,每行包含三个整数 k,a,b (k=0,表示求子数列[a,b]的和;k=1,表示第 a 个数加 b)。
数列从 1 开始计数。
输出格式
输出若干行数字,表示 k=0 时,对应的子数列 [a,b] 的连续和。
数据范围
1≤n≤100000,
1≤m≤100000,
1≤a≤b≤n,
数据保证在任何时候,数列中所有元素之和均在 int 范围内。
输入样例:
10 5
1 2 3 4 5 6 7 8 9 10
1 1 5
0 1 3
0 4 8
1 7 5
0 4 8
输出样例:
11
30
35
题解:
题目要求点单修改和区间查询,如果使用前缀和,我们需要维护两个数组,一个原数组,一个前缀和数组,进行单点修改时就在原数组中修改,时间复杂度是O(1),但每次单点修改后,需要更新一遍前缀和数组,时间复杂度是O(n),区间查询的复杂度为O(n) , 这样的话,如果n次操作中大多数都是单点修改,那么总的时间复杂度就来到了O(n ^ 2)。显然会TLE。
线段树就可以很好地解决上述问题,线段树可以将每次单点修改的复杂度变成O(logn),把区间查询的复杂度变成O(logn),这样n次操作的复杂度就可以降低为O(nlogn)。
那么如何维护一颗线段树:
首先,线段树是一种二叉树结构,我们考虑如何存储一颗线段树
常见的线段树有两种存储方式:
一种是数组(打比赛常用,这里只讲数组实现方式),一种是链表。
那么用数组如何做呢,这里的存储方式类似于堆的数组实现方式。
我们开一个结构体数组,这个结构体数组需要记录区间左右端点,以及这个区间的和。
struct node{
int l , r;
int sum;
}tr[N << 2]; //数组要开4倍
接着,我们需要初始化一颗线段树。这里需要用到递归和回溯的处理方式。
我们分别递归创建这个结点的左右儿子 直到叶子节点,当递归到叶子结点时,每个结点的区间和就是原来数组中每个数,在处理完叶子节点后的回溯过程中由于已经处理出了左右儿子的区间和,那么此时结点的区间和就是左右儿子的区间和之和。
void build(int k , int l , int r){
tr[k] = {l , r};
if (l == r){
tr[k].sum = a[l];
return ;
}
int mid = l + r >> 1;
build(lc , l , mid);
build(rc , mid + 1 , r);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
单点修改
定义一个add函数来进行单点修改,我们从根结点开始查询我们要修改的点在哪个区间,相信聪明的读者这一步把无需我多解释。
需要多提一句的是,回溯时记得更新当前节点的值。
void add(int k , int pos , int delta){
if (tr[k].l == pos && tr[k].r == pos){
tr[k].sum += delta;
return;
}
int mid = tr[k].l + tr[k].r >> 1;
if (pos <= mid) add(lc , pos , delta);
else if (pos > mid) add(rc , pos , delta);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
区间查询
思想和单点修改一样,但是这里是:当查询区间包含当前结点的区间时,说明是我们需要的区间。
如果当前区间包含查询区间,说明区间大了,不是我们需要的区间,递归处理左子树和右子树。
int query(int k , int l , int r){
if (tr[k].l >= l && tr[k].r <= r){
return tr[k].sum;
}
int s = 0;
int mid = tr[k].l + tr[k].r >> 1;
if (l <= mid) s += query(lc , l , r);
if (r > mid) s += query(rc , l , r);
return s;
}
完整代码参考
#include <iostream>
#define lc k << 1
#define rc k << 1 | 1
using namespace std;
const int N = 1e5 + 5;
int n , m;
int a[N];
struct node{
int l , r;
int sum;
}tr[N << 2];
void build(int k , int l , int r){
tr[k] = {l , r};
if (l == r){
tr[k].sum = a[l];
return ;
}
int mid = l + r >> 1;
build(lc , l , mid);
build(rc , mid + 1 , r);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
int query(int k , int l , int r){
if (tr[k].l >= l && tr[k].r <= r){
return tr[k].sum;
}
int s = 0;
int mid = tr[k].l + tr[k].r >> 1;
if (l <= mid) s += query(lc , l , r);
if (r > mid) s += query(rc , l , r);
return s;
}
void add(int k , int pos , int delta){
if (tr[k].l == pos && tr[k].r == pos){
tr[k].sum += delta;
return;
}
int mid = tr[k].l + tr[k].r >> 1;
if (pos <= mid) add(lc , pos , delta);
else if (pos > mid) add(rc , pos , delta);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
int main(){
scanf ("%d %d" , &n , &m);
for (int i = 1 ; i <= n ; ++ i) scanf ("%d" , &a[i]);
build(1 , 1 , n);
for (int i = 0 ; i < m ; ++ i){
int op , a , b;
scanf ("%d" , &op);
if (op == 0){
scanf("%d %d" , &a , &b);
printf("%d\n" , query(1 , a , b));
}
else{
scanf("%d %d" , &a , &b);
add(1 , a , b);
}
}
return 0;
}
二. 线段树(带懒标记,但不下放)
说明
这个版本的线段树比较好理解,标记打到区间覆盖就可以了
在查询某个区间时,这个区间的值由两部分构成:这个区间本来的值 + (这个区间的长度)*(从该结点到根节点上所有的标记和).
例题
#include <iostream>
#define lc k << 1
#define rc k << 1 | 1
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n , m;
ll a[N];
struct node{
int l , r;
ll lazy;
ll sum;
}tr[N << 2];
void build(int k , int l , int r){
tr[k] = {l , r};
if (l == r){
tr[k].sum = a[l];
return;
}
int mid = l + r >> 1;
build(lc , l , mid);
build(rc , mid + 1 , r);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
void add(int k , int l , int r , ll delta){
if (tr[k].l == l && tr[k].r == r){ //用区间相等的方法
tr[k].lazy += delta;
return;
}
tr[k].sum = tr[k].sum + (r - l + 1) * delta;
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) add(lc , l , r , delta);
else if(l > mid) add(rc , l , r , delta);
else{
add(lc , l , mid , delta);
add(rc , mid + 1 , r , delta);
}
}
/*
void add(int k , int l , int r , ll delta){
if (tr[k].l >= l && tr[k].r <= r){ //用区间覆盖的方法,,不推荐
tr[k].lazy += delta;
return ;
}
int mid = tr[k].l + tr[k].r >> 1;
if (l <= mid){
tr[k].sum = tr[k].sum + (min(mid , r) - max(l , tr[k].l) + 1) * delta;
add(lc , l , r , delta);
}
if (r > mid){
tr[k].sum = tr[k].sum + (min(tr[k].r , r) - max(l , (mid + 1)) + 1) * delta;
add(rc , l , r , delta);
}
}
*/
ll query(int k , int l , int r , ll acc){
acc += tr[k].lazy;
if (tr[k].l == l && tr[k].r == r){
return tr[k].sum + (tr[k].r - tr[k].l + 1) * acc;
}
int mid = tr[k].l + tr[k].r >> 1;
ll s = 0;
if(r <= mid) s = query(lc , l , r , acc);
else if (l > mid) s = query(rc , l , r , acc);
else s = query(lc , l , mid , acc) + query(rc , mid + 1 , r , acc);
return s;
}
/*
第二种query方法
ll query(int k , int l , int r , ll acc){
acc += tr[k].lazy;
if(tr[k].l >= l && tr[k].r <= r){
return tr[k].sum + (tr[k].r - tr[k].l + 1) * acc;
}
int mid = tr[k].l + tr[k].r >> 1;
ll s = 0;
if (l <= mid) s += query(lc , l , r , acc);
if (r > mid) s += query(rc , l , r , acc);
return s;
}
*/
int main(){
scanf ("%d %d" , &n , &m);
for (int i = 1 ; i <= n ; ++ i) scanf ("%lld" , &a[i]);
build(1 , 1 , n);
for(int i = 0 ; i < m ; ++ i){
int op , x , y;
ll k;
scanf ("%d" , &op);
if (op == 1){
scanf("%d %d %ld" , &x , &y , &k);
add(1 , x , y , k);
}
else{
scanf ("%d %d" , &x , &y);
printf("%lld\n" , query(1 , x , y , 0));
}
}
return 0;
}
三.线段树(懒标记下放版1)
说明
考虑到最后查询的方便,只改变对应标记的值
即:在乘,加,查询的操作过程中不改变原来的区间值
只在回溯过程中更新父节点的区间和。
这里还有一个问题要考虑:
当我们找到查询区间时 , 我们知道这个区间本来的值 , 和在这个区间上的懒标记的值,但是我们怎么知道是先加再乘 还是 先乘后加呢?
先来看这样一个问题:
对于一个x , 我们对x作如下操作:
(
x
+
2
)
∗
2
+
2
(x + 2) * 2 + 2
(x+2)∗2+2
我们把操作展开:
2
∗
x
+
2
∗
2
+
2
2*x + 2*2 + 2
2∗x+2∗2+2
所以只要我们在对乘标记作操作时,顺带把加标记一起乘,最后就可以把表达式转换成多项式相加就可以了。
例题
#include <iostream>
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n , q;
ll m;
ll a[N];
struct node{
int l , r;
ll sum;
ll tag1 , tag2;
}tr[N << 2];
void build(int k , int l , int r){
tr[k] = {l , r};
tr[k].tag2 = 1;
if (l == r){
tr[k].sum = a[l] % m;
return ;
}
int mid = l + r >> 1;
build(k << 1 , l , mid);
build(k << 1 | 1 , mid + 1 , r);
tr[k].sum = (tr[k << 1].sum + tr[k << 1 | 1].sum) % m;
}
void pushdown(int k){
tr[k << 1].tag1 = (tr[k << 1].tag1 * tr[k].tag2 % m + tr[k].tag1) % m;
tr[k << 1].tag2 = tr[k << 1].tag2 * tr[k].tag2 % m;
tr[k << 1 | 1].tag1 = (tr[k << 1 | 1].tag1 * tr[k].tag2 % m + tr[k].tag1) % m;
tr[k << 1 | 1].tag2 = tr[k << 1 | 1].tag2 * tr[k].tag2 % m;
tr[k].tag1 = 0;
tr[k].tag2 = 1;
}
void pushup(int k){
tr[k].sum = (tr[k << 1].sum * tr[k << 1].tag2 +
(tr[k << 1].r - tr[k << 1].l + 1) * tr[k << 1].tag1 % m) % m
+ (tr[k << 1 | 1].sum * tr[k << 1 | 1].tag2
+ (tr[k << 1 | 1].r - tr[k << 1 | 1].l + 1) * tr[k << 1 | 1].tag1 % m) % m;
tr[k].sum %= m;
}
void mul(int k , int l , int r , ll delta){
if (tr[k].l == l && tr[k].r == r){
tr[k].tag1 = tr[k].tag1 * delta % m;
tr[k].tag2 = tr[k].tag2 * delta % m;
return ;
}
pushdown(k);
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) mul(k << 1 , l , r , delta);
else if (l > mid) mul(k << 1 | 1 , l , r , delta);
else mul(k << 1 , l , mid , delta) , mul(k << 1 | 1 , mid + 1 , r , delta);
pushup(k);
}
void add(int k , int l , int r , ll delta){
if (tr[k].l == l && tr[k].r == r){
tr[k].tag1 = (tr[k].tag1 + delta) % m;
return ;
}
pushdown(k);
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) add(k << 1 , l , r , delta);
else if (l > mid) add(k << 1 | 1 , l , r , delta);
else add(k << 1 , l , mid , delta) , add(k << 1 | 1 , mid + 1 , r , delta);
pushup(k);
}
ll query(int k , int l , int r){
ll s = 0;
if (tr[k].l == l && tr[k].r == r){
return (tr[k].sum * tr[k].tag2 + (tr[k].r - tr[k].l + 1) * tr[k].tag1 % m) % m;
}
pushdown(k);
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) s = query(k << 1 , l , r);
else if (l > mid) s = query(k << 1 | 1 , l , r);
else s = query(k << 1 , l , mid) + query(k << 1 | 1 , mid + 1 , r);
pushup(k);
return s;
}
int main(){
scanf ("%d %d %d" , &n , &q , &m);
for (int i = 1 ; i <= n ; ++ i) scanf ("%lld" , &a[i]);
build(1 , 1 , n);
for (int i = 0 ; i < q ; ++ i){
int op , x , y;
ll k;
scanf ("%d" , &op);
if (op == 1){
scanf ("%d %d %lld" , &x , &y , &k);
mul(1 , x , y , k);
}
else if (op == 2){
scanf ("%d %d %lld" , &x , &y , &k);
add(1 , x , y , k);
}
else{
scanf ("%d %d" , &x , &y);
printf("%lld\n" , query(1 , x , y) % m);
}
}
return 0;
}
四. 线段树(懒标记下放版2)
说明
在下放标记的过程中同时更新该节点左右儿子结点的值,同时清空该节点的标记。
例题
以线段树1为例
#include <iostream>
#define lc k << 1
#define rc k << 1 | 1
using namespace std;
typedef long long ll;
const int N = 1e5 + 5;
int n , m;
ll a[N];
struct node{
int l , r;
ll fg;
ll sum;
}tr[N << 2];
void build(int k , int l , int r){
tr[k] = {l , r};
if (l == r){
tr[k].sum = a[l];
return ;
}
int mid = l + r >> 1;
build(lc , l , mid);
build(rc , mid + 1 , r);
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
void spread(int k){
tr[lc].sum = tr[lc].sum + (tr[lc].r - tr[lc].l + 1) * tr[k].fg;
tr[rc].sum = tr[rc].sum + (tr[rc].r - tr[rc].l + 1) * tr[k].fg;
tr[lc].fg += tr[k].fg;
tr[rc].fg += tr[k].fg;
tr[k].fg = 0;
}
void add(int k , int l , int r , int delta){
if (tr[k].l == l && tr[k].r == r){
tr[k].sum = tr[k].sum + (r - l + 1) * delta;
tr[k].fg += delta;
return ;
}
spread(k);
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) add(lc , l , r , delta);
else if (l > mid) add(rc , l , r , delta);
else{
add(lc , l , mid , delta);
add(rc , mid + 1 , r , delta);
}
tr[k].sum = tr[lc].sum + tr[rc].sum;
}
ll query(int k , int l , int r){
if (tr[k].l == l && tr[k].r == r){
return tr[k].sum;
}
spread(k);
ll s = 0;
int mid = tr[k].l + tr[k].r >> 1;
if (r <= mid) s = query(lc , l , r);
else if (l > mid) s = query(rc , l , r);
else s = query(lc , l , mid) + query(rc , mid + 1 , r);
tr[k].sum = tr[lc].sum + tr[rc].sum;
return s;
}
int main(){
scanf ("%d %d" , &n , &m);
for (int i = 1 ; i <= n ; ++ i) scanf ("%lld" , &a[i]);
build(1 , 1 , n);
for (int i = 0 ; i < m ; ++ i){
int op , x , y;
ll k;
scanf ("%d" , &op);
if (op == 1){
scanf ("%d %d %lld" , &x , &y , &k);
add(1 , x , y , k);
}
else{
scanf ("%d %d" , &x , &y);
printf("%lld\n" , query(1 , x , y));
}
}
return 0;
}