4.4线段树
关于线段树
- 线段树是一种可以储存一个有限大小数组的区间和,区间最值等具有分配率的运算结果的数据结构
- 与树状数组相比,线段树简单易懂,功能强大,但缺点是空间需求大,代码量大
- 线段树利用二叉树进行存储,而储存不能连续所以需要占用值域四倍的空间,这导致此类题目
n
n
n最大约为
5
e
5
5e5
5e5
- 线段树采用递归遍历
线段树的实现
首先要有的
- 大小为
4
n
4n
4n的tree数组和大小为
n
n
n的原数组
- 遍历时采用
k
,
l
,
r
k,l,r
k,l,r对节点进行描述,即该节点为第
k
k
k个节点,表示的区间为
l
−
r
l-r
l−r
建树
- 由根节点开始,对区间[l,r] 不断取mid,递归访问[l,mid]和[mid+1,r]直至遍历到根节点,将根节点的权值赋值为原数组第l位对应的数字
void build(int k,int l,int r){
if(l==r){
sum[k]=a[l];
return;
}
int mid=(l+r)>>1;
build(2*k,l,mid);
build(2*k+1,mid+1,r);
sum[k]=sum[2*k]+sum[2*k+1];
}
单点修改
- 单点修改时,通过不断向子树递归,直到找到对应节点时,将节点修改,然后回溯修改父亲即可
void change(int k,int l,int r,int x,int v){
if(l==r){
sum[k]+=v;
return;
}
int mid=(l+r)>>1;
if(x<=mid) change(2*k,l,mid,x,v);
else change(2*k+1,mid+1,r,x,v);
sum[k]=sum[2*k]+sum[2*k+1];
}
区间修改
- 一个直接的想法:遍历到区间内所有节点并对区间内所有节点进行修改
- 这种方法时间复杂度显然是无法接受的
- 考虑:每次修改时,是否有必要让区间内每个节点的sum都发生改变?
- 这样,我们引入伟大的懒标记
- 每次对区间进行修改时,递归到查询区间的子区间时,在改区间对应的节点打上一个懒标记,表示该节点的每个子节点的sum都需要加上这个懒标记所记录的值,然后就回溯
- 此后,每访问到一个节点时首先将它的懒标记下放一层,然后在进行其他操作
- 由于对线段树任一的维护只需要用到该节点的两个子节点,所以懒标记的正确性是显然的
void Add(int k, int l, int r, int v)
{
add[k] += v;
sum[k] += (r - l + 1) * v;
}
void pushdown(int k, int l, int r, int mid)
{
if (add[k] == 0)
return;
Add(2 * k, l, mid, add[k]);
Add(2 * k + 1, mid + 1, r, add[k]);
add[k] = 0;
}
void change(int k, int l, int r, int x, int y, int v)
{
if (x <= l && y >= r)
{
Add(k, l, r, v);
return;
}
int mid = (l + r) >> 1;
pushdown(k, l, r, mid);
if (x <= mid)
{
change(2 * k, l, mid, x, y, v);
}
if (y > mid)
{
change(2 * k + 1, mid + 1, r, x, y, v);
}
sum[k] = sum[2 * k] + sum[2 * k + 1];
return;
}
区间查询
- 区间求和时,不断向子树递归,当节点代表区间
[
l
,
r
]
[l,r]
[l,r]被包含于查询区间
[
x
,
y
]
[x,y]
[x,y]时,将ans加上该节点对应权值sum,然后回溯,将ans上传给父亲并更新父节点权值sum,最终结果为左右子树上传的ans之和
int query(int k,int l,int r,int x,int y){
if(x<=l&&y>=r){
return sum[k];
}
int res=0;
int mid=(l+r)>>1;
if(x<=mid){
res+=query(2*k,l,mid,x,y);
}
if(y>mid){
res+=query(2*k+1,mid+1,r,x,y);
}
return res;
}
例题
![](https://i-blog.csdnimg.cn/blog_migrate/b547be66b9a769f2cea74755cbc68030.png)
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5+100;
int sum[4*N],a[N];
int n,m;
void build(int k,int l,int r){
if(l==r){
sum[k]=a[l];
return;
}
int mid=(l+r)>>1;
build(2*k,l,mid);
build(2*k+1,mid+1,r);
sum[k]=sum[2*k]+sum[2*k+1];
}
int query(int k,int l,int r,int x,int y){
if(x<=l&&y>=r){
return sum[k];
}
int res=0;
int mid=(l+r)>>1;
if(x<=mid){
res+=query(2*k,l,mid,x,y);
}
if(y>mid){
res+=query(2*k+1,mid+1,r,x,y);
}
return res;
}
void change(int k,int l,int r,int x,int v){
if(l==r){
sum[k]+=v;
return;
}
int mid=(l+r)>>1;
if(x<=mid) change(2*k,l,mid,x,v);
else change(2*k+1,mid+1,r,x,v);
sum[k]=sum[2*k]+sum[2*k+1];
}
signed main(){
scanf("%lld%lld",&n,&m);
for(int i=1;i<=m;i++){
int k,a,b;
scanf("%lld%lld%lld",&k,&a,&b);
if(k==0){
change(1,1,n,a,b);
}
if(k==1){
printf("%lld\n",query(1,1,n,a,b));
}
}
return 0;
}
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e6 + 100;
int sum[4 * N], a[N], add[4 * N];
int n, q;
void build(int k, int l, int r) {
if (l == r) {
sum[k] = a[l];
return;
}
int mid = (l + r) >> 1;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
sum[k] = sum[2 * k] + sum[2 * k + 1];
return;
}
void Add(int k, int l, int r, int v) {
add[k] += v;
sum[k] += (r - l + 1) * v;
}
void pushdown(int k, int l, int r, int mid) {
if (add[k] == 0)
return;
Add(2 * k, l, mid, add[k]);
Add(2 * k + 1, mid + 1, r, add[k]);
add[k] = 0;
}
void change(int k, int l, int r, int x, int y, int v) {
if (x <= l && y >= r) {
Add(k, l, r, v);
return;
}
int mid = (l + r) >> 1;
pushdown(k, l, r, mid);
if (x <= mid) {
change(2 * k, l, mid, x, y, v);
}
if (y > mid) {
change(2 * k + 1, mid + 1, r, x, y, v);
}
sum[k] = sum[2 * k] + sum[2 * k + 1];
return;
}
int query(int k, int l, int r, int x, int y) {
if (x <= l && y >= r)
return sum[k];
int res = 0;
int mid = (l + r) >> 1;
pushdown(k, l, r, mid);
if (x <= mid) {
res += query(2 * k, l, mid, x, y);
}
if (y > mid) {
res += query(2 * k + 1, mid + 1, r, x, y);
}
return res;
}
signed main() {
scanf("%lld%lld", &n, &q);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= q; i++) {
int o, l, r, x;
scanf("%lld%lld%lld", &o, &l, &r);
if (o == 1) {
scanf("%lld", &x);
change(1, 1, n, l, r, x);
} else if (o == 2) {
printf("%lld\n", query(1, 1, n, l, r));
}
}
return 0;
}
- 思路:
- 一个直接的想法:将维护的区间和改为维护区间最大值,那么如何实现呢?
- 合并:因为最大值运算具有结合律,所以直接用取max代替+即可
- 查询:
- 考虑查询区间可能出现的位置:从左开始,跨过中间,直到右边
- 所以,对每个节点,维护3个信息:区间最大值max,从左开始的最大值lm,到右结束的最大值rm
- 对这三个信息合并时,max为左右max与左的rm+右的lm的max,lm为max(lm,左.max+右.lm),rm为max(rm,左.rm+右.max)
- 其余步骤按区间修改区间查询的方法做即可
- 代码
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int N = 5e5 + 10;
struct point {
int sum, maxx, lm, rm;
point operator+(const point& a) const {
point tmp;
tmp.sum = sum + a.sum;
tmp.maxx = max(max(maxx, a.maxx), rm + a.lm);
tmp.lm = max(lm, sum + a.lm);
tmp.rm = max(a.rm, rm + a.sum);
return tmp;
}
} tree[4 * N];
int n, m, a[N];
void build(int k, int l, int r) {
if (l == r) {
tree[k].sum = a[l];
tree[k].maxx = a[l];
tree[k].lm = a[l];
tree[k].rm = a[l];
return;
}
int mid = (l + r) >> 1;
build(2 * k, l, mid);
build(2 * k + 1, mid + 1, r);
tree[k] = tree[2 * k] + tree[2 * k + 1];
return;
}
point query(int k, int l, int r, int x, int y) {
if (x <= l && y >= r)
return tree[k];
int mid = (l + r) >> 1;
if (y <= mid)
return query(2 * k, l, mid, x, y);
if (x > mid)
return query(2 * k + 1, mid + 1, r, x, y);
return query(2 * k, l, mid, x, y) + query(2 * k + 1, mid + 1, r, x, y);
}
void change(int k, int l, int r, int x, int v) {
if (l == r) {
tree[k].sum = tree[k].maxx = tree[k].lm = tree[k].rm = v;
return;
}
int mid = (l + r) >> 1;
if (x <= mid)
change(2 * k, l, mid, x, v);
else if (x > mid)
change(2 * k + 1, mid + 1, r, x, v);
tree[k] = tree[2 * k] + tree[2 * k + 1];
return;
}
signed main() {
scanf("%lld%lld", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int k, a, b;
scanf("%lld%lld%lld", &k, &a, &b);
if (k == 1) {
if (a > b)
swap(a, b);
printf("%lld\n", query(1, 1, n, a, b).maxx);
} else if (k == 2) {
change(1, 1, n, a, b);
}
}
return 0;
}
- 区间更改区间求和
- 显然,乘法会影响加法,但加法不会影响乘法,所以,我们维护两个懒标记,plu和mul,分别记录乘数和加数,在dfs是先pushdown懒标记mul,后push懒标记plu
- 仿照区间加,区间求和维护即可
#include<bits/stdc++.h>
using namespace std;
#define int long long
const int N = 1e5+10;
int n,q,m,a[N];
struct t{
int sum,mul,plu;
}tr[N*4];
void pushup(int k){
tr[k].sum=tr[k<<1].sum+tr[k<<1|1].sum;
tr[k].sum%=m;
}
void build(int k,int l,int r){
tr[k].mul=1;
if(l==r){
tr[k].sum=a[l];
return ;
}
int mid=(l+r)>>1;
build(k<<1,l,mid);
build(k<<1|1,mid+1,r);
pushup(k);
return;
}
void pushdown(int k,int l,int r){
int k1=k<<1,k2=k<<1|1,mu=tr[k].mul,p=tr[k].plu;
int mid=(l+r)>>1;
tr[k1].sum*=mu;tr[k1].sum%=m;
tr[k1].sum+=p*(mid-l+1)%m;tr[k1].sum%=m;
tr[k2].sum*=mu;tr[k2].sum%=m;
tr[k2].sum+=p*(r-mid)%m;tr[k2].sum%=m;
tr[k1].mul*=mu;tr[k1].mul%=m;
tr[k2].mul*=mu;tr[k2].mul%=m;
tr[k1].plu*=mu;tr[k1].plu%=m;
tr[k1].plu+=p;tr[k1].plu%=m;
tr[k2].plu*=mu;tr[k2].plu%=m;
tr[k2].plu+=p;tr[k2].plu%=m;
tr[k].mul=1;tr[k].plu=0;
return ;
}
void modify_m(int k,int l,int r,int x,int y,int mu){
if(x>r||y<l) return;
if(x<=l&&r<=y){
tr[k].sum*=mu;
tr[k].sum%=m;
tr[k].plu*=mu;
tr[k].plu%=m;
tr[k].mul*=mu;
tr[k].mul%=m;
return ;
}
int mid=(l+r)>>1;
pushdown(k,l,r);
modify_m(k<<1,l,mid,x,y,mu);
modify_m(k<<1|1,mid+1,r,x,y,mu);
pushup(k);
return ;
}
void modify_p(int k,int l,int r,int x,int y,int p){
if(x>r||y<l) return ;
if(x<=l&&r<=y){
tr[k].sum+=p*(r-l+1);tr[k].sum%=m;
tr[k].plu+=p;tr[k].plu%=m;
return ;
}
int mid=(l+r)>>1;
pushdown(k,l,r);
modify_p(k<<1,l,mid,x,y,p);
modify_p(k<<1|1,mid+1,r,x,y,p);
pushup(k);
return ;
}
int query(int k,int l,int r,int x,int y){
if(x>r||y<l) return 0;
if(x<=l&&r<=y){
return tr[k].sum%m;
}
int mid=(l+r)>>1;
pushdown(k,l,r);
int s1,s2;
s1=query(k<<1,l,mid,x,y);
s2=query(k<<1|1,mid+1,r,x,y);
return (s1+s2)%m;
}
signed main(){
scanf("%lld%lld%lld",&n,&q,&m);
for(int i=1;i<=n;i++){
scanf("%lld",&a[i]);
}
build(1,1,n);
for(int i=1;i<=q;i++){
int op,x,y,k;
scanf("%lld",&op);
if(op==1){
scanf("%lld%lld%lld",&x,&y,&k);
modify_m(1,1,n,x,y,k);
}else if(op==2){
scanf("%lld%lld%lld",&x,&y,&k);
modify_p(1,1,n,x,y,k);
}else if(op==3){
scanf("%lld%lld",&x,&y);
printf("%lld\n",query(1,1,n,x,y));
}
}
return 0;
}