Description
给定序列 $a=(a_1,a_2,\cdots,a_n)$,有 $m$ 个操作,分三种:
- $\operatorname{add}(l,r,k)$:对每个 $i\in[l,r]$ 执行 $a_i\gets a_i+k$.
- $\operatorname{sqrt}(l,r)$:对每个 $i\in[l,r]$ 执行 $a_i\gets \lfloor \sqrt{a_i} \rfloor$.
- $\operatorname{query}(l,r)$:求 $\sum_{i=l}^r a_i$.
Limitations
$1\le n,m\le 10^5$
$1\le l\le r\le n$
$1\le a_i,k\le 10^5$
Solution
考虑 $\operatorname{sqrt}$ 操作怎么做.
显然当区间极差为 $0$ 时,直接对整个区间开方(此时等价于一次区间加).
但是有特殊情况 $\lfloor \sqrt x\rfloor + 1= \sqrt{x+1}$(即 $x+1$ 为完全平方数),例如:
原序列为 $(65535,65536,65535,65536,\cdots)$.
开三次平方后变成 $(3,4,3,4,\cdots)$,每次都是最坏复杂度.
然后可以 $\operatorname{add}$ 回去,于是就被卡成了 $O(nm)$.
不过很好解决,若开方前后序列极差均为 $1$,那也直接进行开方(所有元素的变化量相同,等价于一次区间加).
Code
$3.3\text{KB},637\text{ms},11.57\text{MB}\;\texttt{(in total, C++20)}$
#include <bits/stdc++.h>
// NOTE(review): file-scope using-directive. The code below relies on it for
// the unqualified names vector, max, min, ios, cin and cout.
using namespace std;
// Short aliases for the built-in arithmetic types.
using i64 = long long;            // signed integer used for values, sums and tags
using ui64 = unsigned long long;
// __int128 is a compiler extension (GCC/Clang), not standard C++.
using i128 = __int128;
using ui128 = unsigned __int128;
// NOTE(review): the numeric suffix presumably denotes a byte width; the actual
// size of long double is platform-dependent, so f16 is not guaranteed to be 16 bytes.
using f4 = float;
using f8 = double;
using f16 = long double;
/// Raises `a` to `b` when `b` is strictly greater than `a`.
/// @return true if `a` was overwritten, false if it was left untouched.
template<class T>
bool chmax(T &a, const T &b){
    const bool improved = a < b;
    if(improved) a = b;
    return improved;
}
/// Lowers `a` to `b` when `b` is strictly smaller than `a`.
/// @return true if `a` was overwritten, false if it was left untouched.
template<class T>
bool chmin(T &a, const T &b){
    const bool improved = a > b;
    if(improved) a = b;
    return improved;
}
namespace seg_tree {
// Local alias (same type as the file-level one) so this namespace is
// self-contained and does not depend on declarations outside it.
using i64 = long long;

/// One tree node covering the closed index interval [l, r].
struct Node {
    int l, r;                 // inclusive bounds of the covered interval
    i64 max, min, sum, tag;   // aggregates over [l, r]; tag = add not yet pushed to children
};

// Child slots are addressed by the PARENT'S MIDPOINT, not by the parent's own
// index: the children of a node split at `mid` live at 2*mid+1 and 2*mid+2.
// Every internal node has a distinct midpoint, so slots never collide and the
// largest slot used is 2*(n-2)+2 = 2n-2; the root sits at slot 0.
inline int ls(int u) { return 2 * u + 1; }
inline int rs(int u) { return 2 * u + 2; }

/// Exact floor(sqrt(x)) for x >= 0.
/// The floating-point estimate is corrected with integer arithmetic because a
/// double cannot represent every 64-bit integer (e.g. 2^62-1 rounds to 2^62).
/// Negative inputs are outside the intended domain and are mapped to 0 instead
/// of converting a NaN to an integer.
inline i64 isqrt(i64 x) {
    if (x <= 0) return 0;
    const unsigned long long ux = static_cast<unsigned long long>(x);
    unsigned long long r =
        static_cast<unsigned long long>(std::sqrt(static_cast<double>(x)));
    // Unsigned products cannot overflow here: r stays close to sqrt(2^63).
    while (r * r > ux) --r;
    while ((r + 1) * (r + 1) <= ux) ++r;
    return static_cast<i64>(r);
}

/// Segment tree over 0-based positions supporting range add, range
/// floor-square-root and range sum.
struct SegTree {
    std::vector<Node> tr;   // node storage; empty when the tree holds no elements

    SegTree() {}

    /// Builds the tree over `a`. An empty vector yields an empty tree on which
    /// every range operation is a no-op (sum is 0).
    SegTree(const std::vector<i64>& a) {
        const int n = static_cast<int>(a.size());
        if (n == 0) return;   // nothing to build; avoids touching tr[0]
        tr.resize(2 * a.size());
        build(0, 0, n - 1, a);
    }

    /// Recomputes the aggregates of node `u` from its two children, which are
    /// located through the node's midpoint `mid`.
    void pushup(int u, int mid) {
        const Node& left = tr[ls(mid)];
        const Node& right = tr[rs(mid)];
        tr[u].sum = left.sum + right.sum;
        tr[u].max = std::max(left.max, right.max);
        tr[u].min = std::min(left.min, right.min);
    }

    /// Adds `tag` to every element of node `u` and records it as pending.
    void apply(int u, i64 tag) {
        tr[u].sum += tag * (tr[u].r - tr[u].l + 1);
        tr[u].min += tag;
        tr[u].max += tag;
        tr[u].tag += tag;
    }

    /// Forwards the pending add of node `u` to both children and clears it.
    void pushdown(int u, int mid) {
        if (tr[u].tag) {
            apply(ls(mid), tr[u].tag);
            apply(rs(mid), tr[u].tag);
            tr[u].tag = 0;
        }
    }

    /// Recursively initialises node `u` for the interval [l, r] from `a`.
    void build(int u, int l, int r, const std::vector<i64>& a) {
        tr[u].l = l;
        tr[u].r = r;
        if (l == r) {
            tr[u].sum = tr[u].min = tr[u].max = a[l];
            return;
        }
        const int mid = (l + r) >> 1;
        build(ls(mid), l, mid, a);
        build(rs(mid), mid + 1, r, a);
        pushup(u, mid);
    }

    /// Replaces every element of [l, r] inside node `u` by its floor square
    /// root. Requires [l, r] to intersect the node's interval.
    void sqrt(int u, int l, int r) {
        // 0 and 1 are fixed points of floor-sqrt, so nothing can change here.
        if (tr[u].min >= 0 && tr[u].max <= 1) return;
        if (l <= tr[u].l && tr[u].r <= r) {
            const i64 tmin = isqrt(tr[u].min);
            const i64 tmax = isqrt(tr[u].max);
            // All elements equal: the square root is a uniform shift.
            const bool uniform = tr[u].min == tr[u].max;
            // Elements are {x, x+1} and the roots also differ by one, so both
            // values move by the same delta. Without this case an alternating
            // sequence such as 65535,65536,... forces a full descent each time.
            const bool same_shift =
                tr[u].min + 1 == tr[u].max && tmin + 1 == tmax;
            if (uniform || same_shift) {
                apply(u, tmax - tr[u].max);
                return;
            }
        }
        const int mid = (tr[u].l + tr[u].r) >> 1;
        pushdown(u, mid);
        if (l <= mid) sqrt(ls(mid), l, r);
        if (r > mid) sqrt(rs(mid), l, r);
        pushup(u, mid);
    }

    /// Adds `k` to every element of [l, r] inside node `u`. Requires [l, r]
    /// to intersect the node's interval.
    void add(int u, int l, int r, i64 k) {
        if (l <= tr[u].l && tr[u].r <= r) {
            apply(u, k);
            return;
        }
        const int mid = (tr[u].l + tr[u].r) >> 1;
        pushdown(u, mid);
        if (l <= mid) add(ls(mid), l, r, k);
        if (r > mid) add(rs(mid), l, r, k);
        pushup(u, mid);
    }

    /// Returns the sum of the elements of [l, r] inside node `u`. Requires
    /// [l, r] to intersect the node's interval.
    i64 query(int u, int l, int r) {
        if (l <= tr[u].l && tr[u].r <= r) return tr[u].sum;
        const int mid = (tr[u].l + tr[u].r) >> 1;
        i64 res = 0;
        pushdown(u, mid);
        if (l <= mid) res += query(ls(mid), l, r);
        if (r > mid) res += query(rs(mid), l, r);
        return res;
    }

    /// Clips [l, r] to the stored index range.
    /// @return false when the tree is empty or the clipped range is empty, in
    ///         which case the recursive routines must not be entered (they
    ///         assume a non-empty intersection with the root).
    bool clip(int& l, int& r) const {
        if (tr.empty()) return false;
        l = std::max(l, tr[0].l);
        r = std::min(r, tr[0].r);
        return l <= r;
    }

    /// a[i] <- floor(sqrt(a[i])) for every 0-based i in [l, r].
    void range_sqrt(int l, int r) {
        if (clip(l, r)) sqrt(0, l, r);
    }

    /// a[i] <- a[i] + v for every 0-based i in [l, r].
    void range_add(int l, int r, i64 v) {
        if (clip(l, r)) add(0, l, r, v);
    }

    /// Sum of a[i] over the 0-based range [l, r]; 0 for an empty range.
    i64 range_sum(int l, int r) {
        return clip(l, r) ? query(0, l, r) : 0;
    }
};
}
using seg_tree::SegTree;
signed main() {
ios::sync_with_stdio(0);
cin.tie(0), cout.tie(0);
int n, m; scanf("%d %d", &n, &m);
vector<i64> a(n);
for (int i = 0; i < n; i++) scanf("%lld", &a[i]);
SegTree sgt(a);
for (int i = 0, op, l, r, v; i < m; i++) {
scanf("%d %d %d", &op, &l, &r), l--, r--;
if (op == 1) {
scanf("%d", &v);
sgt.range_add(l, r, v);
}
else if (op == 2) sgt.range_sqrt(l, r);
else printf("%lld\n", sgt.range_sum(l, r));
}
return 0;
}