常规的线段树可以用lazy标志来实现线段树的区间更新(区间覆盖,区间加减定值等),但是形如下面操作1却不是很好处理
- For all , change Ai to min(Ai, x)
- Query for the sum of Ai in [l, r]
可以参考2016年国家集训队论文集中的“区间最值与历史最值问题”——吉如一,关于求区间和、区间最值的问题可以用常规的线段树轻松解决。关键是对于操作1,如何去更新的问题。
这里,我采用的是记录每个线段树节点所管辖的A数组区间,显然,管辖区间长度等于当前节点的所有子孙节点中叶子节点的个数。当从一个节点往下进行更新的时候,如果当前节点记录的最大值小于等于需要更新的值x,可以直接return,因为该节点下的所有叶节点的值都不会大于x,不会有更新操作。反之,从该节点往下进行dfs操作,同理,dfs时遇到的节点中记录最大值若小于等于x,也可以直接return(子节点管辖区间是其父节点所管辖区间的子区间),否则,将该节点“删除”(所有记录值清0),并递归向下进行该操作。dfs完成之后,对dfs的初始节点进行更新(update)操作,显然,该节点下的最大值将被更新为x,然后将该节点的sum值加上dfs中“删除”的叶节点数与x的乘积,此时我们仅仅更新了当前一个节点,而其所有子孙节点都没有更新,这其实就和线段树区间更新的lazy操作是一个原理,然后令该节点的标记值tag=x,x为更新值,这样,当查询该节点的子孙节点时我们从该节点一层一层往下pushDown,就相当于把原来“删除”的节点又重建起来了,由于我们不必每次全部更新到叶子节点,其实这个方法是很高效的。
其实就是线段树递归中的思想,能return了(线段树节点所管辖的区间的是查询区间的子区间或者这两个区间不交,即不包含公共元素)立刻return。在dfs和维护线段树时我们也是能return(当前线段树节点下所有叶节点的值都不大于比较值x,这样也就不需要继续往下搜索了)就return。同理,如果更新操作的是保留较大的元素则搜索时只需要判断最小值了。
// http://acm.hdu.edu.cn/showproblem.php?pid=5306
/**
* 0, l, r, val a_i = min(a_i, val) i in [l, r]
* 1, l, r, query max a_i for i in [l, r]
* 2, l, r, query sum a_i for i in [l, r]
**/
#include <bits/stdc++.h>
#define lson(u) (u << 1)
#define rson(u) (u << 1 | 1)
using namespace std;
const int MAXN = 1000000 + 7;
long long a[MAXN];
struct SegmentTreeNode {
int l, r; // 该节点管辖的原序列区间
long long _max, sum, cnt, tag; // 最大值,和,无需更新(pushDown)的叶节点,标记值(往下pushDown更新)
} node[MAXN << 2];
inline void pushUp(int u) {
node[u].sum = node[lson(u)].sum + node[rson(u)].sum;
node[u].cnt = node[lson(u)].cnt + node[rson(u)].cnt;
node[u]._max = max(node[lson(u)]._max, node[rson(u)]._max);
}
void build(int l, int r, int u) {
node[u].tag = 0;
node[u].l = l;
node[u].r = r;
if(l == r) {
node[u].tag = node[u]._max = node[u].sum = a[l];
node[u].cnt = 1;
return ;
}
int m = (l + r) >> 1;
build(l, m, lson(u));
build(m + 1, r, rson(u));
pushUp(u);
return ;
}
inline void update(int u, long long alter) { // 更新节点值,每次只更新线段树中的一个节点
if(node[u].tag != 0 && node[u].tag <= alter) return ;
node[u].tag = alter;
if(node[u].cnt != node[u].r - node[u].l + 1) {
node[u]._max = alter;
node[u].sum += 1LL * (node[u].r - node[u].l + 1 - node[u].cnt) * alter;
node[u].cnt = node[u].r - node[u].l + 1;
}
return ;
}
inline void pushDown(int u) { // 往下pushDown一层更新
if(node[u].tag == 0) return ;
update(lson(u), node[u].tag);
update(rson(u), node[u].tag);
}
inline void dfs(int u, long long alter) { // 暴力递归“删除”需要更新路径上的所有节点
if(node[u]._max <= alter) return ;
node[u].tag = 0;
if(node[u].l == node[u].r) {
node[u]._max = node[u].sum = node[u].cnt = 0;
return ;
}
dfs(lson(u), alter);
dfs(rson(u), alter);
pushUp(u);
return ;
}
inline void modify(long long alter, int l, int r, int u) {
if(node[u]._max <= alter) return ;
if(node[u].r < l || node[u].l > r) return ;
if(l <= node[u].l && node[u].r <= r) {
dfs(u, alter);
update(u, alter); // 先“删除”该节点子孙所有需要“删除”的路径但只更新当前节点,同lazy原理
return ;
}
int m = (node[u].l + node[u].r) >> 1;
pushDown(u);
if(l <= m) {
modify(alter, l, r, lson(u));
}
if(m < r) {
modify(alter, l, r, rson(u));
}
pushUp(u);
return ;
}
long long querySum(int l, int r, int u) {
if(node[u].r < l || node[u].l > r) return 0;
if(l <= node[u].l && node[u].r <= r) {
return node[u].sum;
}
pushDown(u);
int m = (node[u].l + node[u].r) >> 1;
long long ret = 0;
if(l <= m) {
ret += querySum(l, r, lson(u));
}
if(m < r) {
ret += querySum(l, r, rson(u));
}
pushUp(u);
return ret;
}
long long queryMax(int l, int r, int u) {
if(node[u].l > r || node[u].r < l) return -1;
if(l <= node[u].l && node[u].r <= r) {
return node[u]._max;
}
pushDown(u);
int m = (node[u].l + node[u].r) >> 1;
long long ret = -1;
if(l <= m) {
ret = max(ret, queryMax(l, r, lson(u)));
}
if(m < r) {
ret = max(ret, queryMax(l, r, rson(u)));
}
pushUp(u);
return ret;
}
int main() {
int T;
scanf("%d", &T);
while(T--) {
int n, m;
scanf("%d %d", &n, &m);
for(int i = 1; i <= n; ++i) {
scanf("%lld", a + i);
}
build(1, n, 1);
while(m--) {
int op, l, r;
scanf("%d %d %d", &op, &l, &r);
if(op == 0) {
long long alter;
scanf("%lld", &alter);
modify(alter, l, r, 1);
} else if(op == 1) {
long long ret = queryMax(l, r, 1);
printf("%lld\n", ret);
} else {
long long ret = querySum(l, r, 1);
printf("%lld\n", ret);
}
}
}
return 0;
}