看到线段树就头大,码量题我板子都背不下来啊啊
写到一半头要炸了(感冒太难受了)于是去摸了个鱼,感觉好点了才继续肝(大雾)
线段树
线段树
(
S
e
g
m
e
n
t
T
r
e
e
)
(Segment\ Tree)
(Segment Tree)是一种基于分治思想的二叉树结构,其基本用途是对序列进行维护,支持查询与修改指令。与按照二进制位(
2
2
2的次幂)进行区间划分的树状数组相比,
线段树是一种更加通用的结构:
1.线段树每个节点都代表一个区间。
2.线段树具有唯一的根节点,代表的区间是整个统计范围,如
[
1
,
N
]
[1,N]
[1,N]。
3.线段树的每一个叶子节点都代表一个长度为
1
1
1的元区间
[
x
,
x
]
[x,x]
[x,x]。
4.对于每个内部节点
[
L
,
R
]
[L,R]
[L,R],它的左子节点是
[
L
,
m
i
d
]
[L,mid]
[L,mid],右子节点是
[
m
i
d
+
1
,
R
]
[mid+1,R]
[mid+1,R],其中
m
i
d
=
(
L
+
R
)
/
2
mid=(L+R)/2
mid=(L+R)/2(向下取整)。
可以发现,除去最后一层,整颗线段树一定是一颗完全二叉树,深度为
O
(
l
o
g
2
N
)
O(log2N)
O(log2N)。
因此,我们可以按照与二叉堆类似的“父子2倍”节点编号方法。
1.根节点编号为
1
1
1。
2.编号为
x
x
x的节点的左子节点编号为
x
∗
2
x*2
x∗2,右子节点编号为:
x
∗
2
+
1
x*2+1
x∗2+1。
我们常用位运算来寻找左右子树:
k
<
<
1
k<<1
k<<1(结点k的左子树下标)
k
<
<
1
∣
1
k<<1|1
k<<1∣1(结点k的右子树下标)
我们用一个
s
t
r
u
c
t
struct
struct数组来保存线段树,树的最后一层节点在数组中保存的位置不是连
续的,那么我们直接空出数组中多余的位置即可。
在理想情况下,共有
N
N
N个叶子节点的满二叉树有
2
N
−
1
2N-1
2N−1个节点
(
N
+
N
/
2
+
N
/
4
+
.
.
.
+
2
+
1
)
(N+N/2+N/4+...+2+1)
(N+N/2+N/4+...+2+1),由于在上述存储方式下,最后一层产生了空余,所以保存线段树的数组长度要 不小于
4
N
4N
4N 才能不越界。
线段树基本操作:
p
u
s
h
u
p
pushup
pushup(将子节点的信息传给父节点),
p
u
s
h
d
o
w
n
pushdown
pushdown(
l
a
z
y
_
t
a
g
lazy\_tag
lazy_tag时下放标记),
b
u
i
l
d
build
build(建树),
m
o
d
i
f
y
modify
modify(根据题目要求进行修改,注意这里参数的改变和
q
u
e
r
y
query
query一样都不要随递归层数进行修改,毕竟要求的东西不变啊),
q
u
e
r
y
query
query(查询,别递归的时候改变查询区间,是找线段树节点区间是否是查询区间的真子集,不要搞混了)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=(int)2e5+50;
int n,m,a[N];
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
struct node{
int l,r,maxn;
}tr[N<<1];
inline void pushup(int p){
tr[p].maxn=max(tr[p<<1].maxn,tr[p<<1|1].maxn);
}
inline void build(int p,int l,int r){
tr[p].l=l,tr[p].r=r;
if(tr[p].l==tr[p].r){
tr[p].maxn=a[l];
return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);build(p<<1|1,mid+1,r);
pushup(p);
}
inline void modify(int p,int pos,int val){
if(tr[p].l==tr[p].r){
tr[p].maxn+=val;return;
}
int mid=(tr[p].l+tr[p].r)>>1;
if(pos<=mid){
modify(p<<1,pos,val);
}
else modify(p<<1|1,pos,val);
pushup(p);
}
inline int query(int p,int l,int r){
int ans=-1;
if(l<=tr[p].l&&r>=tr[p].r) return tr[p].maxn;
int mid=(tr[p].l+tr[p].r)>>1;
if(l<=mid) ans=max(ans,query(p<<1,l,r));
if(r>mid) ans=max(ans,query(p<<1|1,l,r));
return ans;
}
signed main(){
n=read(),m=read();
for(int i=1;i<=n;++i) a[i]=read();
build(1,1,n);
for(int i=1;i<=m;++i){
int opt=read();
if(opt==0){
int i,x;i=read(),x=read();
modify(1,i,x);
}
if(opt==1){
int l=read(),r=read();
cout<<query(1,l,r)<<endl;
}
}
return 0;
}
- 区间修改,单点求值
用到了 l a z y _ t a g lazy\_tag lazy_tag标记。
l a z y _ t a g lazy\_tag lazy_tag:线段树在进行区间更新的时候,为了提高更新的效率,所以每次更新只更新到更新区间完全覆盖线段树结点区间为止,这样就会导致被更新结点的子孙结点的区间得不到需要更新的信息,所以在被更新结点上打上一个标记,称为 l a z y _ t a g lazy\_tag lazy_tag,等到下次访问这个结点的子结点时再将这个标记传递给子结点,所以也可以叫延迟标记。 也就是说递归更新的过程,更新到结点区间为需要更新的区间的真子集不再往下更新,下次若是遇到需要用这下面的结点的信息,再去更新这些结点,所以这样的话使得区间更新的操作和区间查询类似,复杂度为 O ( l o g N ) O(logN) O(logN)。
注意清空已经下放的标记
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int N=(int)1e5+50;
int n,m,a[N];
inline int read(){
int cnt=0,f=1;char c=getchar();
while(!isdigit(c)){if(c=='-')f=-f;c=getchar();}
while(isdigit(c)){cnt=(cnt<<1)+(cnt<<3)+(c^48);c=getchar();}
return cnt*f;
}
struct node{
int l,r,sum,tag;
}tr[N<<2];
inline void pushup(int p){
tr[p].sum=tr[p<<1].sum+tr[p<<1|1].sum;
}
inline void pushdown(int p){
if(tr[p].tag){
tr[p<<1].tag+=tr[p].tag;tr[p<<1|1].tag+=tr[p].tag;
tr[p<<1].sum+=tr[p].tag*(tr[p<<1].r-tr[p<<1].l+1);
tr[p<<1|1].sum+=tr[p].tag*(tr[p<<1|1].r-tr[p<<1|1].l+1);
tr[p].tag=0;
}
}
inline void build(int p,int l,int r){
tr[p].l=l,tr[p].r=r;
if(tr[p].l==tr[p].r){
tr[p].sum=a[l];return;
}
int mid=(l+r)>>1;
build(p<<1,l,mid);build(p<<1|1,mid+1,r);
pushup(p);
}
inline void modify(int p,int l,int r,int d){
if(l<=tr[p].l&&r>=tr[p].r){
tr[p].sum+=d,tr[p].tag+=d;return;
}
pushdown(p);
int mid=(tr[p].l+tr[p].r)>>1;
if(l<=mid) modify(p<<1,l,r,d);
if(r>mid) modify(p<<1|1,l,r,d);
pushup(p);
}
inline int query(int p,int x){
if(tr[p].l==tr[p].r) return tr[p].sum;
int mid=(tr[p].l+tr[p].r)>>1;
pushdown(p);
if(x<=mid) return query(p<<1,x);
else return query(p<<1|1,x);
}
signed main(){
freopen("1.in","r",stdin);
n=read(),m=read();
for(int i=1;i<=n;++i) a[i]=read();
build(1,1,n);
for(int i=1;i<=m;++i){
int opt=read();
if(opt==0){
int l,r,x;l=read(),r=read(),x=read();modify(1,l,r,x);
}
if(opt==1){
int x;x=read();
cout<<query(1,x)<<endl;
}
}
return 0;
}