Transformation
[Problem - 4578 (dingbacode.com)](https://acm.hdu.edu.cn/showproblem.php?pid=4578)
题意
给定一个初始全为 0 的数组,支持区间加、区间乘、区间赋值三种修改,以及查询区间内每个数的 1、2、3 次幂之和(结果对 10007 取模)。
题解
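用线段树维护三个 sum,分别表示区间内各数 1、2、3 次幂之和。区间加 $d$ 时按公式递推:$\sum (a+d)^2 = \sum a^2 + 2d\sum a + len \cdot d^2$,$\sum (a+d)^3 = \sum a^3 + 3d\sum a^2 + 3d^2\sum a + len \cdot d^3$;区间乘 $k$ 时三个 sum 分别乘以 $k$、$k^2$、$k^3$。再维护三个懒标记,分别对应区间加、区间乘、区间赋值。下传或应用标记时,要按"先赋值、再乘、最后加"的顺序处理。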
Code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
#include <set>
#include <queue>
#include <vector>
#include <map>
#include <bitset>
#include <unordered_map>
#include <cmath>
#include <stack>
#include <iomanip>
#include <deque>
#include <sstream>
// Shorthand for std::pair members. These are plain textual macros, so every
// identifier spelled `x` or `y` after this point is rewritten as well.
#define x first
#define y second
using namespace std;

// Short aliases for commonly used types.
using ld = long double;
using LL = long long;
using PII = pair<int, int>;
using PDD = pair<double, double>;
using ULL = unsigned long long;

// N: capacity of the per-index arrays; M: twice N; INF: large int sentinel;
// mod: modulus applied to every stored sum.
const int N = 1e5 + 10;
const int M = 2 * N;
const int INF = 0x3f3f3f3f;
const int mod = 10007;
// Floating-point tolerance constant.
const double eps = 1e-8;

// Four unit offsets; (dx[i], dy[i]) is one axis-aligned step.
int dx[] = {-1, 0, 1, 0};
int dy[] = {0, 1, 0, -1};
// Arrays backing a linked adjacency list:
//   h[v]  - index of the most recently added edge leaving v
//   e[i]  - target vertex of edge i
//   w[i]  - weight of edge i
//   ne[i] - edge that was at the head of the same list before edge i
//   idx   - next unused edge slot
// NOTE(review): h is zero-initialised and never reset in the visible code, so
// slot 0 and "no edge" are indistinguishable; nothing visible here calls add().
int h[N];
int e[M], ne[M], w[M];
int idx;

/**
 * Inserts the directed edge a -> b at the front of a's edge list.
 *
 * @param a source vertex
 * @param b target vertex
 * @param v edge weight (defaults to 0)
 */
void add(int a, int b, int v = 0) {
    const int slot = idx++;
    e[slot] = b;
    w[slot] = v;
    ne[slot] = h[a];
    h[a] = slot;
}
// n: array length, m: number of operations (both read in main).
// k is not referenced by the visible code.
int n, m, k;

// One segment-tree node covering the index range [l, r].
// Field order matters: build() fills nodes with positional brace initialisation.
struct Node {
    int l;
    int r;
    // Sums of the 1st, 2nd and 3rd powers of the covered elements, modulo `mod`.
    int sum1;
    int sum2;
    int sum3;
    // Pending tags not yet pushed to the children:
    //   add - value still to be added (0 = none)
    //   mul - factor still to be multiplied in (1 = none)
    //   tob - value the range was assigned to (0 = no assignment pending)
    int add;
    int mul;
    int tob;
};

// Node storage; children of node u live at 2u and 2u+1.
Node tr[N << 2];

// Not referenced by the visible code.
int a[N];
void pushup(int u) {
tr[u].sum1 = (tr[u << 1].sum1 + tr[u << 1 | 1].sum1) % mod;
tr[u].sum2 = (tr[u << 1].sum2 + tr[u << 1 | 1].sum2) % mod;
tr[u].sum3 = (tr[u << 1].sum3 + tr[u << 1 | 1].sum3) % mod;
}
/**
 * Applies a combined update to one node, in the fixed order
 * assign -> multiply -> add, and folds it into the node's pending tags.
 *
 * @param op   node to update (sums and tags are modified in place)
 * @param add  value to add to every element; 0 means "no add"
 * @param mul  factor to multiply every element by; 1 means "no multiply"
 * @param lazy value to assign to every element; 0 means "no assignment"
 *
 * NOTE(review): all products are formed in plain int, so callers are assumed
 * to pass values small enough not to overflow - confirm against input limits.
 */
void eval(Node& op, int add, int mul, int lazy) {
    // Product of two values reduced modulo `mod`.
    const auto mulmod = [](int p, int q) { return p * q % mod; };
    const int len = op.r - op.l + 1;

    // Assignment: every element becomes `lazy`, so the sums are len * lazy^k
    // and any previously pending add/mul tags are discarded.
    if (lazy != 0) {
        lazy %= mod;
        const int sq = mulmod(lazy, lazy);
        op.sum1 = mulmod(lazy, len);
        op.sum2 = mulmod(sq, len);
        op.sum3 = mulmod(mulmod(sq, lazy), len);
        op.add = 0;
        op.mul = 1;
        op.tob = lazy;
    }

    // Multiply before add: the k-th power sum scales by mul^k, and the pending
    // add tag must be scaled too, not only the sums.
    if (mul != 1) {
        op.sum1 = mulmod(op.sum1, mul);
        op.sum2 = mulmod(mulmod(op.sum2, mul), mul);
        op.sum3 = mulmod(mulmod(mulmod(op.sum3, mul), mul), mul);
        op.add = mulmod(mul, op.add);
        op.mul = mulmod(op.mul, mul);
    }

    // Add d to each of the len elements, using
    //   (a + d)^2 = a^2 + 2ad + d^2
    //   (a + d)^3 = a^3 + 3a^2 d + 3a d^2 + d^3
    // Every term is computed from the old sums before any sum is overwritten.
    if (add != 0) {
        const int lenD = mulmod(len, add);                                  // len * d
        const int lenD2 = mulmod(lenD, add);                                // len * d^2
        const int lenD3 = mulmod(lenD2, add);                               // len * d^3
        const int threeS2D = mulmod(mulmod(3, op.sum2), add);               // 3 * sum2 * d
        const int threeS1D2 = mulmod(mulmod(mulmod(3, op.sum1), add), add); // 3 * sum1 * d^2
        const int twoDS1 = mulmod(mulmod(2, add), op.sum1);                 // 2 * d * sum1

        op.sum3 = (op.sum3 + lenD3 + threeS2D + threeS1D2) % mod;
        op.sum2 = (op.sum2 + twoDS1 + lenD2) % mod;
        op.sum1 = (op.sum1 + lenD) % mod;
        op.add = (op.add + add) % mod;
    }
}
/**
 * Pushes node u's pending tags to both children, then resets u's tags to
 * the identity (add 0, mul 1, no assignment).
 *
 * @param u index of an internal node in tr
 */
void down(int u) {
    const int pendingAdd = tr[u].add;
    const int pendingMul = tr[u].mul;
    const int pendingSet = tr[u].tob;

    eval(tr[u << 1], pendingAdd, pendingMul, pendingSet);
    eval(tr[u << 1 | 1], pendingAdd, pendingMul, pendingSet);

    tr[u].add = 0;
    tr[u].mul = 1;
    tr[u].tob = 0;
}
/**
 * Initialises the subtree rooted at node u to cover [l, r], with all sums
 * zero and all tags at identity.
 *
 * @param u node index in tr
 * @param l left end of the covered range (inclusive)
 * @param r right end of the covered range (inclusive)
 */
void build(int u, int l, int r) {
    // Leaves and internal nodes start out identical: zero sums, mul tag 1.
    tr[u] = {l, r, 0, 0, 0, 0, 1, 0};
    if (l == r) return;

    const int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
}
void modify(int u, int l, int r, int x, int y, int op) {
if (l <= tr[u].l && tr[u].r <= r) {
eval(tr[u], x, y, op);
return ;
}
down(u);
int mid = tr[u].r + tr[u].l >> 1;
if (l <= mid) modify(u << 1, l, r, x, y, op);
if (r > mid) modify(u << 1 | 1, l, r, x, y, op);
pushup(u);
}
/**
 * Returns the sum of c-th powers over [l, r] within the subtree of u,
 * modulo `mod`.
 *
 * @param u current node index in tr
 * @param l left end of the queried range (inclusive)
 * @param r right end of the queried range (inclusive)
 * @param c power selector: 1 or 2 pick sum1 / sum2, anything else picks sum3
 * @return the requested power sum modulo `mod`
 */
LL query(int u, int l, int r, int c) {
    const Node& cur = tr[u];

    // Node fully inside the queried range: its stored sum is the answer.
    if (l <= cur.l && cur.r <= r) {
        switch (c) {
            case 1: return cur.sum1;
            case 2: return cur.sum2;
            default: return cur.sum3;
        }
    }

    // Children must see pending tags before they are read.
    down(u);
    const int mid = (cur.l + cur.r) >> 1;
    LL total = 0;
    if (l <= mid) total = query(u << 1, l, r, c);
    if (r > mid) total = (total + query(u << 1 | 1, l, r, c)) % mod;
    return total % mod;
}
/**
 * Reads test cases of the form "n m" followed by m operations
 * "op l r c", until a "0 0" line or the end of input.
 *
 *   op 1: add c to every element in [l, r]
 *   op 2: multiply every element in [l, r] by c
 *   op 3: assign c to every element in [l, r]
 *   else: print the sum of c-th powers over [l, r] (modulo `mod`)
 */
int main() {
    ios::sync_with_stdio(false);
    cin.tie(0);

    // Test the stream state as well as the "0 0" terminator: without this,
    // input that ends early leaves n == 0 with a stale non-zero m, and
    // build(1, 1, 0) would recurse without bound.
    while (cin >> n >> m && (n | m)) {
        build(1, 1, n);
        while (m --) {
            int op, lo, hi, c;
            // Truncated or malformed operation line: stop instead of acting on
            // unread values.
            if (!(cin >> op >> lo >> hi >> c)) return 0;

            if (op == 1) modify(1, lo, hi, c, 1, 0);
            else if (op == 2) modify(1, lo, hi, 0, c, 0);
            else if (op == 3) modify(1, lo, hi, 0, 1, c);
            // '\n' rather than endl: avoids flushing after every answer.
            else cout << query(1, lo, hi, c) << '\n';
        }
    }
    return 0;
}