题目描述
如题,已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上 x
2.将某区间每一个数加上 x
3.求出某区间每一个数的和
分析:显然这题比p3372更复杂一点(主要体现在既有加又有乘的操作)
那么对于我们的add,mul的tag就有一个先后顺序
即每次更新mul的时候我们都要把之前的add乘上这个
但是更新add的时候不考虑mul
#include <bits/stdc++.h>
#define MAXN 100010
#define ll long long
using namespace std;
int n, m, mod;
int a[MAXN];
struct node
{
ll sum, add, mul;
int l, r;
} s[MAXN * 4];
void update(int pos)
{
s[pos].sum = (s[pos << 1].sum + s[pos << 1 | 1].sum) % mod;
return;
}
void pushdown(int pos)
{
s[pos << 1].sum = (s[pos << 1].sum * s[pos].mul + s[pos].add * (s[pos << 1].r - s[pos << 1].l + 1)) % mod;
s[pos << 1 | 1].sum = (s[pos << 1 | 1].sum * s[pos].mul + s[pos].add * (s[pos << 1 | 1].r - s[pos << 1 | 1].l + 1)) % mod;
s[pos << 1].mul = (s[pos << 1].mul * s[pos].mul) % mod;
s[pos << 1 | 1].mul = (s[pos << 1 | 1].mul * s[pos].mul) % mod;
s[pos << 1].add = (s[pos << 1].add * s[pos].mul + s[pos].add) % mod;
s[pos << 1 | 1].add = (s[pos << 1 | 1].add * s[pos].mul + s[pos].add) % mod;
s[pos].add = 0;
s[pos].mul = 1;
return;
}
void build(int pos, int l, int r)
{
s[pos].l = l;
s[pos].r = r;
s[pos].mul = 1;
if (l == r)
{
s[pos].sum = a[l] % mod;
return;
}
int mid = (l + r) >> 1;
build(pos << 1, l, mid);
build(pos << 1 | 1, mid + 1, r);
update(pos);
return;
}
void mul(int pos, int x, int y, int k)
{
if (x <= s[pos].l && s[pos].r <= y)
{
s[pos].add = (s[pos].add * k) % mod;
s[pos].mul = (s[pos].mul * k) % mod;
s[pos].sum = (s[pos].sum * k) % mod;
return;
}
pushdown(pos);
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid)
mul(pos << 1, x, y, k);
if (y > mid)
mul(pos << 1 | 1, x, y, k);
update(pos);
return;
}
void add(int pos, int x, int y, int k)
{
if (x <= s[pos].l && s[pos].r <= y)
{
s[pos].add = (s[pos].add + k) % mod;
s[pos].sum = (s[pos].sum + k * (s[pos].r - s[pos].l + 1)) % mod;
return;
}
pushdown(pos);
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid)
add(pos << 1, x, y, k);
if (y > mid)
add(pos << 1 | 1, x, y, k);
update(pos);
return;
}
ll ask(int pos, int x, int y)
{
if (x <= s[pos].l && s[pos].r <= y)
{
return s[pos].sum;
}
pushdown(pos);
ll val = 0;
int mid = (s[pos].l + s[pos].r) >> 1;
if (x <= mid)
val = (val + ask(pos << 1, x, y)) % mod;
if (y > mid)
val = (val + ask(pos << 1 | 1, x, y)) % mod;
return val;
}
int main()
{
scanf("%d%d%d", &n, &m, &mod);
for (int i = 1; i <= n; i++)
{
scanf("%d", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++)
{
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if (opt == 1)
{
int k;
scanf("%d", &k);
mul(1, x, y, k);
}
if (opt == 2)
{
int k;
scanf("%d", &k);
add(1, x, y, k);
}
if (opt == 3)
{
printf("%lld\n", ask(1, x, y));
}
}
return 0;
}