参考资料:《进阶指南》
线段树是一种基于分治思想的二叉树结构,用于在区间上进行信息统计。
- 线段树每一个节点代表一个区间
- 线段树具有唯一的一个根节点,代表的区间是整个统计范围,[1,N]
- 线段树的叶节点都代表一个长度为1的元区间[x,x]
- 对于每个内部节点[l,r],左节点是[l,mid],右节点是[mid+1,r],其中mid=(l+r)/2;
根节点编号为1,编号为x的节点的左节点为x*2,右节点为x*2+1;
当我们用Struct来存线段树时,数据要开四倍,证明如下:
如上图,不是一颗完全二叉树,我们设倒数第二层(就是满节点的那层节点数为N),则从这一层开始到根的总节点数是N+N/2+N/4+……+2+1=2N-1;
并且最后一层如果是满的情况下,节点数是2N,所以加起来一共是4N,所以至少开4倍空间。
结构体代码:
struct SeamengTree
{
int l,r;
int dat;
}t[maxn*4];//struct存线段树
我们现在来建树。给定一个长度为N的序列A,我们在区间【1,N】上建立线段树。每个节点存下相应的信息(根据题目定,但是要能满足区间可加性)
比如每个节点存最大值dat(l,r);dat(l,r)=max(dat(l,mid),dat(mid+1,r));
图源:秦淮岸灯火阑珊(已退役)
在建树的时候采用递归去建立,递归边界,当递归到叶子节点时,也就是[x,x],其值就是一个点的值,也就是我们输入的元素的值。
不然的话就先建立左子树,递归到叶节点,然后回溯建立对应右子树,再回溯;【PS:这个过程和dfs回溯递归是有些类似的】【自己写递归的时候思考最大的那一层比较好写。毕竟我是蒟蒻】
比如这里要保存每个节点(一个节点代表一个区间)的最大值,那么在回溯的时候从下往上也就是父亲节点传递信息
void build(int p,int l,int r)//最开始传入(1,1,n)
{
t[p].l=l;t[p].r=r;
if(l==r) {t[p].dat=a[i];return;}
int mid=(l+r)/2;
build(p*2,l,mid);
build(p*2,mid+1,r);
t[p].dat=max(t[p*2].dat,t[p*2+1].dat);
}
在最开始的时候是单点修改(以后区间修改可以直接写单点修改,不过是修改区间是一个点[x,x]同时不用lazy就是了)
当我们修改一个点的时候,我们要递归到最底层也就是叶子节点去修改叶子节点的信息值,比如修改一个值,我们现在节点存的是区间最大值,回溯的时候就要更新信息。
void change(int p,int x,int v)//调用(1,7,1) p是起初是根节点,x是要改的单点
{
if(t[p].l==t[p].r){t[p].dat=v;return;}//说明找到了要改的单点区间
int mid=(t[p].l+t[p].r)/2;
if(x<=mid) change(p*2,x,v);//说明要找的x在p点的左半区间
else change(p*2+1,x,v);
t[p].dat=max(t[p*2].dat,t[p*2+1].dat);//修改后更新父亲的保存信息
}
线段树的区间查询
比如说我们现在要找序列A在区间[l,r]上的最大值,我们从根节点开始,递归去找:
- 如果要找的[l,r]完全包含了当前节点代表的区间,直接返回就好了(比如根节点代表[1,10],完全覆盖的情况也就是找区间[l=1,r=10]//[l=0,r=11]的最大值,所以直接返回信息就好了)
- 如果要找的[l,r]与左子节点有重叠部分,就递归找访问左子节点[重叠的那部分访问下去,迟早找到完全覆盖的区间,直接返回]
- 如果要找的[l,r]与右子节点有重叠部分,就递归找访问右子节点[ 重叠的那部分访问下去,迟早找到完全覆盖的区间,直接返回]
比如根节点代表的区间是[1,10],现在要查询[l=2,r=8]的区间最大值
过程为:先递归访问左节点[1,5]在里面找[2,8];
再进去在[1,3]里面找[2,8];
再进去在[1,2]里面找[2,8];
再进去在[2,2]里面找[2,8];此刻完全覆盖找到,返回节点的dat;
回溯到[3,3]里面找[2,8];完全覆盖,返回节点的dat;
接着找[1,5]右子节点的[2,8],完全覆盖,返回节点的dat;
回去找[1.10]右子节点的[6,10]的[2,8];找[6.10]的左节点[6,8],完全覆盖,返回子节点的dat
int ask(int p,int l,int r)
{
if(l<=t[p].l&&r>=t[p].r) return t[p].dat;
int mid=(t[p].l+t[p].r)/2;
int val=-(1<<30);//负无穷大
if(l<=mid) val=max(val,ask(p*2,l,r));//左子节点有重叠
if(r>mid) val=max(val,ask(p*2,l,r));//右子节点有重叠
return val;
}
查询区间[l,r]在线段树上分成O(logn)个节点,取他们的最大值作为答案。时间复杂度O(log(N));和ST表比起来,能动态处理区间最值(RMQ)问题了。
lazy标记的诞生环境:当我们要区间修改的时候,复杂度会发生变化,原来单点修改是O(log(N)),当变成一个区间时,修改的复杂度为O(NlogN),lazy标记就是让复杂度降低为logN。
比如说我要改一个区间[l,r],使其内的每一个点都加d;假如这个点是完全覆盖的,那我们在这个节点额外添加一个信息lazy,当后面要查询的区间也是这个完全覆盖的区间的时候就可以直接返回了
那假如要改的区间[l,r]没有完全覆盖呢?
和我们之前的操作一样,递归下去一定能找到完全覆盖的,然后打上lazy标记。
比如在第三层的某一个节点x代表的区间是完全覆盖的,那我们就标记lazy。假如后面的区间修改的区间在这个节点x下面,那我们就消去这个节点的lazy,找到这个x节点下面的完全覆盖的区间进行打lazy标记就好啦。
或者要找的时候在这个节点x代表的区间下面,我们也要把这个节点的lazy消去,对这个节点下面完全覆盖的区间进行lazy标记
同时注意一下,打的标记的数就是加减的数,每一个节点的增减量就是(t[p].r-t[p].l+1)*x(因为是区间里每一个点都同加减一个数)
题目描述
如题,已知一个数列,你需要进行下面两种操作:
- 将某区间每一个数加上 kk。
- 求出某区间每一个数的和。
输入格式
第一行包含两个整数 n, mn,m,分别表示该数列数字的个数和操作的总个数。
第二行包含 nn 个用空格分隔的整数,其中第 ii 个数字表示数列第 ii 项的初始值。
接下来 mm 行每行包含 33 或 44 个整数,表示一个操作,具体如下:
1 x y k
:将区间 [x, y][x,y] 内每个数加上 kk。2 x y
:输出区间 [x, y][x,y] 内每个数的和。
输出格式
输出包含若干行整数,即为所有操作 2 的结果。
输入输出样例
输入 #1复制
5 5 1 5 4 2 3 2 2 4 1 2 3 2 2 3 4 1 1 5 1 2 1 4
输出 #1复制
11 8 20
说明/提示
对于 30\%30% 的数据:n \le 8n≤8,m \le 10m≤10。
对于 70\%70% 的数据:n \le {10}^3n≤103,m \le {10}^4m≤104。
对于 100\%100% 的数据:1 \le n, m \le {10}^51≤n,m≤105。
保证任意时刻数列中任意元素的和在 [-2^{63}, 2^{63})[−263,263) 内。
#include<iostream>
#include<vector>
#include<queue>
#include<cstring>
#include<cmath>
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=1e5+10;
typedef long long LL;
struct SegmentTree{
LL l,r;
LL sum,add;
#define l(x) tree[x].l
#define r(x) tree[x].r
#define sum(x) tree[x].sum
#define add(x) tree[x].add
}tree[maxn*4];
LL a[maxn],n,m;
void build(LL p,LL l,LL r)
{
l(p)=l;r(p)=r;
if(l==r) {sum(p)=a[l];return;}
LL mid=(l+r)/2;
build(p*2,l,mid);
build(p*2+1,mid+1,r);
sum(p)=sum(p*2)+sum(p*2+1);
}
void spread(LL p)
{
if(add(p))
{
sum(p*2)+=add(p)*(r(p*2)-l(p*2)+1);
sum(p*2+1)+=add(p)*(r(p*2+1)-l(p*2+1)+1);
add(p*2)+=add(p);add(p*2+1)+=add(p);
add(p)=0;
}
}
void change(LL p,LL l,LL r,LL d)
{
if(l<=l(p)&&r>=r(p))
{
sum(p)+=(LL)d*(r(p)-l(p)+1);
add(p)+=d;
return;
}
spread(p);
LL mid=(l(p)+r(p))/2;
if(l<=mid) change(p*2,l,r,d);///(l,r)
if(r>mid) change(p*2+1,l,r,d);
sum(p)=sum(p*2)+sum(p*2+1);
}
LL ask(LL p,LL l,LL r)
{
if(l<=l(p)&&r>=r(p)) return sum(p);
spread(p);
LL mid=(l(p)+r(p))/2;
LL val=0;
if(l<=mid) val+=ask(p*2,l,r);//(l,r)
if(r>mid) val+=ask(p*2+1,l,r);//(l,r)
return val;
}
int main(void)
{
cin.tie(0);std::ios::sync_with_stdio(false);
cin>>n>>m;
for(int i=1;i<=n;i++) cin>>a[i];
build(1,1,n);
while(m--)
{
int op;LL l,r,d;
cin>>op;
if(op==1)
{
cin>>l>>r>>d;
change(1,l,r,d);
}
else if(op==2){cin>>l>>r;cout<<ask(1,l,r)<<endl;}
}
return 0;
}