使用场景
-
对数列进行区间询问(包括最值、求和、乘积等询问)。
-
对数列进行区间修改(统一赋值、增减)。
使用思想
分治。
详解
手造一段数列:
2 5 9 1 7 6 5 3
给定 m 次操作:
type1. 将区间 [L,R]加上 val。
type2. 询问区间[L,R] 的元素和。
算法实现
-
建一棵线段树
-
查询函数
-
修改函数
时间复杂度分析
-
线段树结点编号达到 4×n。
-
build
O(n) -
query
O(log n) -
update
O(n)(未优化)
未优化代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=2e5+5;
int a[N],tree[4*N], n, m;
void pushup(int cur)
{
tree[cur]=tree[2*cur]+tree[2*cur+1];
return ;
}
void build(int cur, int lt, int rt)
{
if(lt==rt)
{
tree[cur]=a[lt];
return ;
}
int mid=(lt+rt)>>1;
build(cur*2,lt,mid);
build(cur*2+1,mid+1,rt);
pushup(cur);
return ;
}
int query(int cur, int lt, int rt, int qx, int qy)
{
if(qy<lt||qx>rt)
{
return 0;
}
if(qx<=lt&&rt<=qy)
{
return tree[cur];
}
int mid=lt+rt>>1;
return query(cur*2,lt,mid,qx,qy)+query(cur*2+1,mid+1,rt,qx,qy);
}
void update(int cur, int lt, int rt, int qx, int qy, int val)
{
if(qy<lt||qx>rt)
{
return ;
}
if(lt==rt)
{
tree[cur]+=val;
return ;
}
int mid=lt+rt>>1;
update(cur*2,lt,mid,qx,qy,val);
update(cur*2+1,mid+1,rt,qx,qy,val);
pushup(cur);
}
signed main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
while(m--)
{
int opt, x, y, val;
cin>>opt>>x>>y;
if(opt==1)
{
cin>>val;
update(1,1,n,x,y,val);
}
if(opt==2)
{
cout<<query(1,1,n,x,y)<<"\n";
}
}
}
线段树的优化
很容易发现,单次修改的时间复杂度很慢。
考虑优化,采用懒标记。
性质:线段树的修改是为询问而服务的。
维护标记 tag_{cur} 表示结点 cur 需要修改的值。
修改后 update
函数单次时间复杂度可达 O(log n)
优化后代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int a[N], n, m, tag[4*N], tree[4*N];
void pushup(int cur)
{
tree[cur]=tree[cur*2]+tree[cur*2+1];
return ;
}
void addtag(int cur, int lt, int rt, int val)
{
tag[cur]+=val;
tree[cur]+=(rt-lt+1)*val;
return ;
}
void pushdown(int cur, int lt, int rt)
{
if(tag[cur]==0)
{
return ;
}
int mid=lt+rt>>1;
addtag(cur*2,lt,mid,tag[cur]);
addtag(cur*2+1,mid+1,rt,tag[cur]);
tag[cur]=0;
return ;
}
void build(int cur, int lt, int rt)
{
if(lt==rt)
{
tree[cur]=a[lt];
return ;
}
int mid=lt+rt>>1;
build(cur*2,lt,mid);
build(cur*2+1,mid+1,rt);
pushup(cur);
return ;
}
int query(int cur, int lt, int rt, int qx, int qy)
{
if(qy<lt||qx>rt)
{
return 0;
}
if(qx<=lt&&rt<=qy)
{
return tree[cur];
}
pushdown(cur,lt,rt);
int mid=lt+rt>>1;
return query(cur*2,lt,mid,qx,qy)+query(cur*2+1,mid+1,rt,qx,qy);
}
void update(int cur, int lt, int rt, int qx, int qy, int val)
{
if(qy<lt||qx>rt)
{
return ;
}
if(qx<=lt&&rt<=qy)
{
addtag(cur,lt,rt,val);
return ;
}
pushdown(cur,lt,rt);
int mid=lt+rt>>1;
update(cur*2,lt,mid,qx,qy,val);
update(cur*2+1,mid+1,rt,qx,qy,val);
pushup(cur);
return ;
}
signed main()
{
cin>>n>>m;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
build(1,1,n);
while(m--)
{
int opt, x, y, val;
cin>>opt>>x>>y;
if(opt==1)
{
cin>>val;
update(1,1,n,x,y,val);
}
else
{
cout<<query(1,1,n,x,y)<<"\n";
}
}
return 0;
}
带乘法的线段树
多维护一个乘法标记,随加法标记更新,注意运算顺序。
代码
#include<bits/stdc++.h>
#define int long long
using namespace std;
const int N=1e5+5;
int a[N], n, m, tag[4*N], tree[4*N], mod, mul[4*N];
void pushup(int cur)
{
tree[cur]=tree[cur*2]%mod+tree[cur*2+1]%mod;
return ;
}
void addtag(int cur, int lt, int rt, int val)
{
tag[cur]+=val;
tree[cur]+=(rt-lt+1)*val%mod;
return ;
}
void addtag1(int cur, int lt, int rt, int val)
{
tag[cur]=tag[cur]*val%mod;
mul[cur]=mul[cur]*val%mod;
tree[cur]=tree[cur]*val%mod;
return ;
}
void pushdown(int cur, int lt, int rt)
{
if(tag[cur]==0&&mul[cur]==1)
{
return ;
}
int mid=(lt+rt)>>1;
addtag1(cur*2,lt,mid,mul[cur]);
addtag1(cur*2+1,mid+1,rt,mul[cur]);
addtag(cur*2,lt,mid,tag[cur]);
addtag(cur*2+1,mid+1,rt,tag[cur]);
tag[cur]=0;
mul[cur]=1;
return ;
}
void build(int cur, int lt, int rt)
{
if(lt==rt)
{
tree[cur]=a[lt];
return ;
}
int mid=lt+rt>>1;
build(cur*2,lt,mid);
build(cur*2+1,mid+1,rt);
pushup(cur);
return ;
}
int query(int cur, int lt, int rt, int qx, int qy)
{
if(qy<lt||qx>rt)
{
return 0;
}
if(qx<=lt&&rt<=qy)
{
return tree[cur];
}
pushdown(cur,lt,rt);
int mid=lt+rt>>1;
return query(cur*2,lt,mid,qx,qy)+query(cur*2+1,mid+1,rt,qx,qy);
}
void update(int cur, int lt, int rt, int qx, int qy, int val)
{
if(qy<lt||qx>rt)
{
return ;
}
if(qx<=lt&&rt<=qy)
{
addtag(cur,lt,rt,val);
return ;
}
pushdown(cur,lt,rt);
int mid=lt+rt>>1;
update(cur*2,lt,mid,qx,qy,val);
update(cur*2+1,mid+1,rt,qx,qy,val);
pushup(cur);
return ;
}
void update1(int cur, int lt, int rt, int qx, int qy, int val)
{
if(qy<lt||qx>rt)
{
return ;
}
if(qx<=lt&&rt<=qy)
{
addtag1(cur,lt,rt,val);
return ;
}
pushdown(cur,lt,rt);
int mid=lt+rt>>1;
update1(cur*2,lt,mid,qx,qy,val);
update1(cur*2+1,mid+1,rt,qx,qy,val);
pushup(cur);
return ;
}
signed main()
{
cin>>n>>m>>mod;
for(int i=1;i<=n;i++)
{
cin>>a[i];
}
for(int i=1;i<=4*n;i++)
{
mul[i]=1;
}
build(1,1,n);
while(m--)
{
int opt, x, y, val;
cin>>opt>>x>>y;
if(opt==2)
{
cin>>val;
update(1,1,n,x,y,val);
}
else if(opt==1)
{
cin>>val;
update1(1,1,n,x,y,val);
}
else
{
cout<<query(1,1,n,x,y)%mod<<"\n";
}
}
return 0;
}