线段树模板题,含lazytag的线段树码量本身就比较大,再加入乘法标记,还要考虑先乘后加的问题,本蒟蒻一调就是几个小时。
P3373 【模板】线段树 2https://www.luogu.com.cn/problem/P3373
#include <bits/stdc++.h>
#define sc(n) scanf("%d", &n)
#define int long long
#define endl '\n'
#define FAST ios::sync_with_stdio(false);
typedef long long ll;
using namespace std;
int n, m, mod;
const int N = 1e5 + 10;
struct node
{
int l, r, v, mul, add;
} t[5*N];
int a[N];
void build(int p, int l, int r)
{
t[p].mul = 1;t[p].add = 0;t[p].l = l, t[p].r = r;
if (l == r)
{
t[p].v = a[l]%mod;
return;
}
int mid = l + r >> 1;
build(p << 1, l, mid);
build(p << 1 | 1, mid + 1, r);
t[p].v = (t[p << 1].v + t[p << 1 | 1].v)%mod;
}
void push_down(int p)
{
t[p << 1].v = (ll)(t[p].mul * t[p << 1].v + ((t[p << 1].r - t[p << 1].l + 1) * t[p].add)%mod)%mod;
t[p << 1 | 1].v = (ll)(t[p].mul * t[p << 1 | 1].v%mod + ((t[p << 1 | 1].r - t[p << 1 | 1].l+1) * t[p].add)%mod)%mod;
t[p << 1].mul = (ll)(t[p << 1].mul * t[p].mul) % mod;
t[p << 1 | 1].mul = (ll)(t[p << 1 | 1].mul * t[p].mul)% mod;
t[p << 1].add = (ll)((t[p << 1].add * t[p].mul)%mod+t[p].add) % mod;
t[p << 1 | 1].add =(ll)((t[p << 1 | 1].add * t[p].mul)%mod+t[p].add) % mod;
t[p].mul = 1, t[p].add = 0;
return;
}
int query(int p, int l, int r)
{
if (t[p].l >= l && t[p].r <= r)
{
return t[p].v;
}
push_down(p);
int res = 0;
if (t[p<<1].r >= l)
res = (res+query(p << 1, l, r))%mod;
if (t[p<<1|1].l <= r)
res = (res+query(p << 1 | 1, l, r))%mod;
return res;
}
void add(int p, int l, int r, int val)
{
if (t[p].l >= l && t[p].r <= r)
{
t[p].v = (t[p].v+((t[p].r - t[p].l + 1) * val)%mod)% mod;
t[p].add =(val+t[p].add)% mod;
return;
}
push_down(p);
int mid = l + r >> 1;
if (t[p<<1].r >= l)
add(p << 1, l, r, val);
if (t[p<<1|1].l <= r)
add(p << 1 | 1, l, r, val);
t[p].v = (t[p << 1].v + t[p << 1 | 1].v)%mod;
}
void mul(int p, int l, int r, int val)
{
if (t[p].l >= l && t[p].r <= r)
{
t[p].v =(t[p].v*val) % mod;
t[p].mul = (t[p].mul*val) % mod;
t[p].add = (t[p].add*val) % mod;
return;
}
push_down(p);
int mid = l + r >> 1;
if (t[p<<1].r >= l)
mul(p << 1, l, r, val);
if (t[p<<1|1].l <= r)
mul(p << 1 | 1, l, r, val);
t[p].v = (t[p << 1].v + t[p << 1 | 1].v)%mod;
}
signed main()
{
FAST;
cin >> n >> m >> mod;
for (int i = 1; i <= n; i++)
cin >> a[i];
build(1, 1, n);
int op, x, y;
while (m--)
{
cin >> op;
if (op == 1)
{
int k;
cin >> x >> y >> k;
mul(1, x, y, k);
}
else if (op == 2)
{
int k;
cin >> x >> y >> k;
add(1, x, y, k);
}
else if (op == 3)
{
cin >> x >> y;
cout << query(1, x, y) << endl;
}
}
return 0;
}