题意:
给定一个长度为 n n n 的数组 a a a,再有 m m m 次操作,① 1 , l , r , x 1, l, r, x 1,l,r,x,区间加 x x x;② 2 , l , r 2, l, r 2,l,r,区间开根;③ 3 , l , r 3, l, r 3,l,r,询问区间和。 ( n , m , a i , x ≤ 1 0 5 ) (n, m, a_i, x \leq 10^5) (n,m,ai,x≤105)
链接:
https://vjudge.net/problem/HDU-5828
解题思路:
做法与 https://blog.csdn.net/weixin_44059127/article/details/104978510 大同小异。由于区间加法操作的存在,不能直接暴力
d
f
s
dfs
dfs 到叶结点修改,那么考虑开根操作对区间极差的影响,直观上可以感觉到区间极差应也是
O
(
l
o
g
l
o
g
a
)
O(logloga)
O(logloga) 次后变为
0
0
0。设
y
1
=
⌊
x
1
⌋
,
y
2
=
⌊
x
2
⌋
,
x
=
x
1
−
x
2
≥
0
,
y
=
y
1
−
y
2
≥
0
y_1 = \lfloor\sqrt{x_1}\rfloor, y_2 = \lfloor\sqrt{x_2}\rfloor, x = x_1 - x_2 \geq 0, y = y_1 - y_2 \geq 0
y1=⌊x1⌋,y2=⌊x2⌋,x=x1−x2≥0,y=y1−y2≥0。
{
x
1
=
y
1
2
+
r
1
,
0
≤
r
1
≤
2
y
1
x
2
=
y
2
2
+
r
2
,
0
≤
r
2
≤
2
y
2
\begin{cases} x_1 = y_1^2 + r_1, &0 \leq r_1 \leq 2y_1\\ x_2 = y_2^2 + r_2, & 0 \leq r_2 \leq 2y_2 \end{cases}
{x1=y12+r1,x2=y22+r2,0≤r1≤2y10≤r2≤2y2
作差,假设
y
≥
1
y \geq 1
y≥1,已知
y
2
≥
1
y_2 \geq 1
y2≥1,
x
=
y
(
y
1
+
y
2
)
+
(
r
1
−
r
2
)
⇒
x
−
y
(
y
+
y
2
)
≥
2
y
2
⇒
y
≤
⌈
x
+
3
⌉
−
1
\begin{matrix} & x = y(y_1 + y_2) + (r_1 - r_2) \\ ~ \\ \Rightarrow & x - y(y + y_2) \geq 2y_2 \\ ~ \\ \Rightarrow & y \leq \lceil\sqrt{x + 3}~\rceil - 1 \end{matrix}
⇒ ⇒x=y(y1+y2)+(r1−r2)x−y(y+y2)≥2y2y≤⌈x+3 ⌉−1
这仅是缩放得到的一个上界,当可以看到区间极差随着开根操作进行的递减趋势。当
x
>
2
x \gt 2
x>2,有
y
<
x
y \lt x
y<x,特殊的,可以知道当
x
=
1
x = 1
x=1,可能进行多次开根后仍使得
y
=
1
y = 1
y=1 并最终才变为
0
0
0,而从这个上界无法得知是否
x
=
2
x = 2
x=2 也出现这种情况,回代反证可知不存在
x
=
2
x = 2
x=2,
y
=
2
y = 2
y=2,这意味着当
x
=
y
x = y
x=y 时我们对区间开根操作可以直接更新返回(仅当
x
=
0
,
1
x = 0, 1
x=0,1 时可能成立,可以变成一次区间减法),否则才向下暴力,而区间加法不改变区间极差,那么复杂度得到保证。
综上,总时间复杂度为 O ( m l o g n + n l o g n l o g l o g a ) O(mlogn + nlognlogloga) O(mlogn+nlognlogloga)。
参考代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
typedef pair<int, int> pii;
#define sz(a) ((int)a.size())
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define gmid (l + r >> 1)
const int maxn = 1e5 + 5;
const int inf = 0x3f3f3f3f;
const int mod = 1e9 + 7;
int a[maxn], sq[maxn];
int n, m;
inline ll mysqrt(ll x){
return x < maxn ? sq[x] : sqrt(x + 0.5);
}
struct SegTree{
ll sum[maxn << 2], mx[maxn << 2], mn[maxn << 2], add[maxn << 2];
void pushUp(int rt){
sum[rt] = sum[lson] + sum[rson];
mx[rt] = max(mx[lson], mx[rson]);
mn[rt] = min(mn[lson], mn[rson]);
}
void build(int l, int r, int rt){
add[rt] = 0;
if(l == r){
sum[rt] = mx[rt] = mn[rt] = a[l];
return;
}
int mid = gmid;
build(l, mid, lson);
build(mid + 1, r, rson);
pushUp(rt);
}
void pushDown2(int rt, int son, int len){
if(add[rt]){
add[son] += add[rt];
sum[son] += len * add[rt];
mx[son] += add[rt], mn[son] += add[rt];
}
}
void pushDown(int l, int r, int rt){
int mid = gmid;
pushDown2(rt, lson, mid - l + 1);
pushDown2(rt, rson, r - mid);
add[rt] = 0;
}
void update(int l, int r, int rt, int L, int R, int val){
if(l >= L && r <= R){
add[0] = val;
pushDown2(0, rt, r - l + 1);
return;
}
int mid = gmid;
pushDown(l, r, rt);
if(L <= mid) update(l, mid, lson, L, R, val);
if(R > mid) update(mid + 1, r, rson, L, R, val);
pushUp(rt);
}
void update(int l, int r, int rt, int L, int R){
if(l >= L && r <= R && mx[rt] - mn[rt] == mysqrt(mx[rt]) - mysqrt(mn[rt])){
add[0] = mysqrt(mx[rt]) - mx[rt];
pushDown2(0, rt, r - l + 1);
return;
}
int mid = gmid;
pushDown(l, r, rt);
if(L <= mid) update(l, mid, lson, L, R);
if(R > mid) update(mid + 1, r, rson, L, R);
pushUp(rt);
}
ll query(int l, int r, int rt, int L, int R){
if(l >= L && r <= R) return sum[rt];
int mid = gmid; ll ret = 0;
pushDown(l, r, rt);
if(L <= mid) ret += query(l, mid, lson, L, R);
if(R > mid) ret += query(mid + 1, r, rson, L, R);
return ret;
}
} tr;
const int maxs = 1e3 + 5;
char buf[maxs], *p1 = buf, *p2 = buf;
inline char fr(){
return p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, maxs, stdin)) == p1 ? -1 : *p1++;
}
#define gc fr()
inline void read(int &x){
char ch; while(!isdigit(ch = gc)); x = ch ^ 48;
while(isdigit(ch = gc)) x = x * 10 + (ch ^ 48);
}
int main(){
// ios::sync_with_stdio(0); cin.tie(0);
for(int i = 1; i < maxn; ++i){
sq[i] = (int)sqrt(i + 0.5);
}
int t; read(t);
while(t--){
read(n), read(m);
for(int i = 1; i <= n; ++i){
read(a[i]);
}
tr.build(1, n, 1);
while(m--){
int opt, x, y, z; read(opt), read(x), read(y);
if(opt == 1){
read(z);
tr.update(1, n, 1, x, y, z);
}
else if(opt == 2){
tr.update(1, n, 1, x, y);
}
else{
ll ret = tr.query(1, n, 1, x, y);
printf("%lld\n", ret);
}
}
}
return 0;;
}