P3373 【模板】线段树 2
题目描述
如题,已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
输入格式
第一行包含三个整数N、M、P,分别表示该数列数字的个数、操作的总个数和模数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数乘上k
操作2: 格式:2 x y k 含义:将区间[x,y]内每个数加上k
操作3: 格式:3 x y 含义:输出区间[x,y]内每个数的和对P取模所得的结果
输出格式
输出包含若干行整数,即为所有操作3的结果。
输入输出样例
输入 #1复制
5 5 38
1 5 4 2 3
2 1 4 1
3 2 5
1 2 4 2
2 3 5 5
3 1 4
输出 #1复制
17
2
说明/提示
时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=8,M<=10
对于70%的数据:N<=1000,M<=10000
对于100%的数据:N<=100000,M<=100000
想起来好久没有更博了——刷了道模板题熟悉下线段树
不多说直接贴代码——
#include<bits/stdc++.h>
using namespace std;
typedef long long ll ;
const int N = 1e5+5;
int n, a[N], m, P;
struct node{
int l, r;
ll add, mul, sum;
void update(int Mul, int Add)
{
sum = (1ll*sum*Mul + Add*1ll*(r-l+1))%P;
add = (add*Mul + Add) % P;
mul = (mul * Mul) % P;
}
}tree[N<<2];
void push_up(int x)
{
tree[x].sum = (tree[x<<1].sum + tree[x<<1|1].sum) % P;
}
void push_down(int x)
{
int Mul = tree[x].mul;
int Add = tree[x].add;
if(Mul!=1 || Add!=0){
tree[x<<1].update(Mul,Add);
tree[x<<1|1].update(Mul,Add);
tree[x].add = 0, tree[x].mul = 1;
}
}
void build(int x,int l, int r)
{
tree[x].l = l, tree[x].r = r;
tree[x].add = tree[x].sum = 0,tree[x].mul = 1; //注意乘法初始值为1
//emmm调了好久才发现
if(l == r){
tree[x].sum = a[l];
}else{
int mid = l+r >> 1;
build(x<<1,l,mid);
build(x<<1|1,mid+1,r);
push_up(x);
}
}
void update(int x, int l, int r, ll val,int fg)
{
int L = tree[x].l, R = tree[x].r;
if(l<=L && R<=r){
if(fg==1) tree[x].update(val,0);
else tree[x].update(1,val);
}else{
push_down(x);
int mid = (L+R)/2;
if(mid>= l) update(x<<1,l,r,val,fg);
if(r > mid) update(x<<1|1,l,r,val,fg);
push_up(x);
}
}
ll query(int x, int l, int r)
{
int L = tree[x].l, R = tree[x].r;
if(l<=L && R<=r){
return tree[x].sum%P;
}else{
push_down(x);
ll Ans = 0;
int mid = (L+R)/2;
if(mid>= l) Ans = (Ans + query(x<<1,l,r)) % P;
if(r > mid) Ans = (Ans + query(x<<1|1,l,r)) % P;
push_up(x);
//cout<<x<<"----"<<Ans<<endl;
return Ans;
}
}
int main()
{
scanf("%d%d%d",&n,&m,&P);
for(int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,1,n);
for(int i=1;i<=m;i++)
{
int op, l, r;
scanf("%d%d%d",&op,&l,&r);
if(op != 3){
ll val; scanf("%lld",&val);
update(1,l,r,val,op);
}
else printf("%lld\n",query(1,l,r));
}
return 0;
}