和这道题极为相似的还有一道题如下
给出一个长度为 n n n 的序列 a a a ,有 m m m 个操作 ( 1 ≤ n , m ≤ 1 0 5 ; 1 ≤ a [ i ] ≤ 1 0 9 ) (1 \le n,m \le 10^5;1 \le a[i] \le 10^9) (1≤n,m≤105;1≤a[i]≤109) ,分为以下三种:
-
1 l r 1 \ l \ r 1 l r:查询序列中区间 [ l , r ] [l,r] [l,r] 的和。 ( 1 ≤ l ≤ r ≤ n ) (1 \le l \le r \le n) (1≤l≤r≤n)
-
2 l r x 2 \ l \ r \ x 2 l r x:把序列中区间 [ l , r ] [l,r] [l,r] 的所有数 m o d x \mod \ x mod x。 ( 1 ≤ l ≤ r ≤ n ; 1 ≤ x ≤ 1 0 9 ) (1 \le l \le r \le n;1 \le x \le 10^9) (1≤l≤r≤n;1≤x≤109)
-
3 k x 3 \ k \ x 3 k x:把序列中位置为 k k k 的数改为 x x x。 ( 1 ≤ k ≤ n ; 1 ≤ x ≤ 1 0 9 ) (1 \le k \le n;1 \le x \le 10^9) (1≤k≤n;1≤x≤109)
输出每个操作 1 1 1 得到的答案。
本来对于线段树的区间修改我们优先考虑懒标记的,但是这道题吧会出现一个非常棘手的现象,使用懒标记反而不好处理
因为
(
a
+
b
)
m
o
d
x
=
a
m
o
d
x
+
b
m
o
d
x
(a+b)\mod x=a\mod x+b\mod x
(a+b)modx=amodx+bmodx
举个栗子:a=5,b=1,x=3
因此我们不能使用懒标记先对大区间求模,再递归到小区间继续求模。
因此,这两道题我们都使用到了这个性质:
因为对于求模都有的一个性质如果
a
m
o
d
x
,
a
<
x
a\mod x,a< x
amodx,a<x那么这个式子不成立,因此我们在线段树开一个数值来维护该区间当前的最大值,如果最大值大于或等于我们现在的x那么就继续递归,反之就不用。
#include<iostream>
using namespace std;
const int N=1e5+5;
typedef long long ll;
ll w[N];
struct node{
int l,r;
ll maxx;
ll sum;
}tr[N*4];
void pushup(node &fa,node &left,node &right){
fa.sum=left.sum+right.sum;
fa.maxx=max(left.maxx,right.maxx);
}
void pushup(int p){
pushup(tr[p],tr[p<<1],tr[p<<1|1]);
}
void build(int p,int l,int r){
tr[p]={l,r};
if(l==r){
tr[p].maxx=w[l];
tr[p].sum=w[l];
return;
}
int mid=l+r>>1;
build(p<<1,l,mid);
build(p<<1|1,mid+1,r);
pushup(p);
}
void modd(int p,ll x){
if(tr[p].maxx<x)return;
if(tr[p].l==tr[p].r){
tr[p].sum%=x;
tr[p].maxx%=x;
return;
}
//int mid=tr[p].l+tr[p].r>>1;
modd(p<<1,x);
modd(p<<1|1,x);
pushup(p);
return;
}
void changemod(int p,int l,int r,ll x){
if(tr[p].l>=l&&tr[p].r<=r){
modd(p,x);
return;
}
if(tr[p].l>r||tr[p].r<l)return;
//int mid=tr[p].l+tr[p].r>>1;
changemod(p<<1|1,l,r,x);
changemod(p<<1,l,r,x);
pushup(p);
}
void change(int p,int k,ll x){
if(tr[p].l==tr[p].r&&tr[p].l==k){
tr[p].sum=tr[p].maxx=x;
return ;
}
int mid=tr[p].l+tr[p].r>>1;
if(k<=mid)change(p<<1,k,x);
else change(p<<1|1,k,x);
pushup(p);
}
ll query(int p,int l,int r){
//cout<<tr[p].l<<" "<<tr[p].r<<"\n";
if(tr[p].l>=l&&tr[p].r<=r){
return tr[p].sum;
}
if(tr[p].l>r||tr[p].r<l)return 0;
//int mid=tr[p].l+tr[p].r>>1;
ll k=0;
k+=query(p<<1,l,r);
k+=query(p<<1|1,l,r);
return k;
}
int main(){
int n,m;
cin>>n>>m;
for(int i=1;i<=n;i++)cin>>w[i];
build(1,1,n);
while(m--){
int q,l,r;ll k,x;
cin>>q;
if(q==1){
cin>>l>>r;
cout<<query(1,l,r)<<"\n";
}else if(q==2){
cin>>l>>r>>x;
changemod(1,l,r,x);
}else{
cin>>k>>x;
change(1,k,x);
}
}
}