这道题一直压着想A了,今天终于A了。在网上看题解的时候发现这样一句话,感觉写的很对。“线段树的标记只用来下传”,也就是说当我们在打标记时,这个标记于我们的当前序列无关,而这就要求我们在更改标签时顺便把它所在的区间给改了。同时我们可以发现乘法标签和加法标签的关系,即一个数是ax+b,当我们给它乘c时,变成acx+bc,这就意味这我们需要将加法和乘法标签全乘c,而当我们加c时,变成ax+b+c,只需将加法标签加c即可。
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define maxn 100009
int n,p,a[maxn],m;
struct tree
{
int l,r;
long long sum,lazy1,lazy2;
}t[3*maxn];
void build(int x,int l,int r)
{
t[x].l=l;t[x].r=r;t[x].lazy1=1;
if (l==r)
{
t[x].sum=a[l];
return;
}
int mid=(l+r)>>1;
build(2*x,l,mid);build(2*x+1,mid+1,r);
t[x].sum=(t[2*x].sum+t[2*x+1].sum)%p;
}
void update(int x)
{
if (t[x].l==t[x].r) return;
if (t[x].lazy1==1&&t[x].lazy2==0) return;
t[2*x].lazy1=(t[x].lazy1*t[2*x].lazy1)%p;
t[2*x].lazy2=(t[x].lazy2+t[2*x].lazy2*t[x].lazy1)%p;
t[2*x+1].lazy1=(t[x].lazy1*t[2*x+1].lazy1)%p;
t[2*x+1].lazy2=(t[x].lazy2+t[2*x+1].lazy2*t[x].lazy1)%p;
t[2*x].sum=(t[2*x].sum*t[x].lazy1+(t[2*x].r-t[2*x].l+1)*t[x].lazy2)%p;
t[2*x+1].sum=(t[2*x+1].sum*t[x].lazy1+(t[2*x+1].r-t[2*x+1].l+1)*t[x].lazy2)%p;
t[x].lazy1=1;t[x].lazy2=0;
}
void add(int x,int l,int r,int z)
{
if (t[x].l==l&&t[x].r==r)
{
t[x].lazy2=(t[x].lazy2+z)%p;
t[x].sum=(t[x].sum+z*(t[x].r-t[x].l+1))%p;
return;
}
update(x);
int mid=(t[x].l+t[x].r)>>1;
if (l>mid) add(2*x+1,l,r,z);
else if (r<=mid) add(2*x,l,r,z);
else
{
add(2*x,l,mid,z);
add(2*x+1,mid+1,r,z);
}
t[x].sum=(t[2*x].sum+t[2*x+1].sum)%p;
}
void multiply(int x,int l,int r,int z)
{
if (t[x].l==l&&t[x].r==r)
{
t[x].lazy1=(t[x].lazy1*z)%p;
t[x].lazy2=(t[x].lazy2*z)%p;
t[x].sum=(t[x].sum*z)%p;
return;
}
update(x);
int mid=(t[x].l+t[x].r)>>1;
if (l>mid) multiply(2*x+1,l,r,z);
else if (r<=mid) multiply(2*x,l,r,z);
else
{
multiply(2*x,l,mid,z);
multiply(2*x+1,mid+1,r,z);
}
t[x].sum=(t[2*x].sum+t[2*x+1].sum)%p;
}
long long query(int x,int l,int r)
{
if (t[x].l==l&&t[x].r==r) return t[x].sum;
update(x);
int mid=(t[x].l+t[x].r)>>1;
if (l>mid) return query(2*x+1,l,r)%p;
else if (r<=mid) return query(2*x,l,r)%p;
else return (query(2*x,l,mid)+query(2*x+1,mid+1,r))%p;
}
int main()
{
scanf("%d%d",&n,&p);
for (int i=1;i<=n;i++) scanf("%d",a+i);
build(1,1,n);
scanf("%d",&m);
for (int i=1;i<=m;i++)
{
int opt,x,y,z;
scanf("%d%d%d",&opt,&x,&y);
if (opt==3) printf("%lld\n",query(1,x,y));
else
{
scanf("%d",&z);
if (opt==2) add(1,x,y,z);
else multiply(1,x,y,z);
}
}
return 0;
}