题意:给你一个序列,需要支持以下操作:1:区间内的所有数加上某个值。2:区间内的所有数除以某个数(向下取整)。3:询问某个区间内的最大值。
思路(从未见过的套路):维护区间最大值和区间最小值,执行2操作时,继续向下寻找子区间,如果区间满足:min - (min / x) == max - (max / x)时,给这个区间内的所有数减去min - (min / x)就可以了。为什么这样做呢?因为向下取整操作变化速度远快于加法,在经过很多次操作后其实序列中的数区域相等,复杂度需要用势能分析之类的,均摊复杂度应该是O(n * (log(n) ^ 2))。
代码:
#include <bits/stdc++.h>
#define LL long long
#define ls (o << 1)
#define rs (o << 1 | 1)
using namespace std;
const int maxn = 200010;
struct Seg {
LL add, mx, mi;
};
Seg tr[maxn * 4];
LL a[maxn];
void pushup(int o) {
tr[o].mx = max(tr[ls].mx, tr[rs].mx);
tr[o].mi = min(tr[ls].mi, tr[rs].mi);
}
void pushdown(int o) {
if(tr[o].add != 0) {
tr[ls].add += tr[o].add;
tr[ls].mi += tr[o].add;
tr[ls].mx += tr[o].add;
tr[rs].add += tr[o].add;
tr[rs].mi += tr[o].add;
tr[rs].mx += tr[o].add;
tr[o].add = 0;
}
}
void dfs(int o, int l, int r, LL val) {
if(tr[o].mi - (tr[o].mi / val) == tr[o].mx - (tr[o].mx / val)) {
LL tmp = tr[o].mi - (tr[o].mi / val);
tr[o].add -= tmp;
tr[o].mi -= tmp;
tr[o].mx -= tmp;
return;
}
int mid = (l + r) >> 1;
pushdown(o);
dfs(ls, l, mid, val);
dfs(rs, mid + 1, r, val);
pushup(o);
}
void build(int o, int l, int r) {
if(l == r) {
tr[o].add = 0;
tr[o].mx = tr[o].mi = a[l];
return;
}
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
pushup(o);
}
void update(int o, int l, int r, int ql, int qr, LL val, bool flag) {
if(l >= ql && r <= qr) {
if(flag == 0) {
tr[o].mi += val;
tr[o].mx += val;
tr[o].add += val;
} else {
dfs(o, l, r, val);
}
return;
}
pushdown(o);
int mid = (l + r) >> 1;
if(ql <= mid) update(ls, l, mid, ql, qr, val, flag);
if(qr > mid) update(rs, mid + 1, r, ql, qr, val, flag);
pushup(o);
}
LL query(int o, int l, int r, int ql, int qr) {
if(l >= ql && r <= qr) {
return tr[o].mx;
}
int mid = (l + r) >> 1;
LL ans = 0;
pushdown(o);
if(ql <= mid) ans = max(ans, query(ls, l, mid, ql, qr));
if(qr > mid) ans = max(ans, query(rs, mid + 1, r, ql, qr));
return ans;
}
int main() {
int op, l, r, x, n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
build(1, 1, n);
for (int i = 1; i <= m; i++) {
scanf("%d", &op);
if(op == 0) {
scanf("%d%d%d", &l, &r, &x);
l++, r++;
update(1, 1, n, l, r, x, 0);
} else if(op == 1) {
scanf("%d%d%d", &l, &r, &x);
l++, r++;
if(x != 1)
update(1, 1, n, l, r, x, 1);
} else {
scanf("%d%d%d", &l, &r, &x);
l++, r++;
printf("%lld\n", query(1, 1, n, l, r));
}
}
}