#define _CRT_SECURE_NO_WARNINGS
#include<iostream>
#include<algorithm>
#include<vector>
using namespace std;
// Capacity of every node array. A segment tree over n elements needs up to 4*n nodes,
// so this supports arrays of up to (maxn - 5) / 4 (about 125000) elements.
const int maxn = 5e5 + 5;
// Sentinel used to initialise query minima/maxima; stored values are assumed to lie in (-INF, INF).
const int INF = 1e8;

// Segment tree over arr[0..n-1] supporting "assign v to every element of [a, b]"
// and "sum / min / max of [a, b]".
// Nodes are 1-indexed: node o covers [L, R]; its children 2o and 2o+1 cover
// [L, m] and [m+1, R] with m = L + (R - L) / 2.
struct segment_tree
{
	// Pending assignment tag per node: setv[o] >= 0 means every element of node o's
	// interval equals setv[o]; -1 means "no tag". Hence only non-negative values can be assigned.
	int setv[maxn];
	// Aggregates of each node's interval. Below a tagged node they may be stale, which is
	// harmless because the tag takes precedence in maintain() and query().
	int maxv[maxn]; int minv[maxn]; int sumv[maxn];
	// Results of the most recent interval_query.
	int _max, _sum, _min;
	// Arguments of the operation in progress (value v, range [y1, y2]) and the array length n.
	int v; int y1, y2; int n;

	// Recursively builds node o covering [L, R] from arr[L..R]. Requires L <= R.
	void dfs(int o, int L, int R, const int* arr)
	{
		if (L == R)
		{
			// Leaves need no setv here: build() has already cleared every tag.
			sumv[o] = minv[o] = maxv[o] = arr[L];
			return;
		}
		int m = L + (R - L) / 2;
		dfs(o * 2, L, m, arr);
		dfs(o * 2 + 1, m + 1, R, arr);
		sumv[o] = sumv[o * 2] + sumv[o * 2 + 1];
		maxv[o] = std::max(maxv[o * 2], maxv[o * 2 + 1]);
		minv[o] = std::min(minv[o * 2], minv[o * 2 + 1]);
	}

	// Builds the tree over arr[0..n-1], discarding all previous assignments.
	// n <= 0 yields an empty tree on which interval_set / interval_query are no-ops.
	void build(const int* arr, int n)
	{
		std::fill(setv, setv + maxn, -1); // -1 = "not assigned"
		this->n = n;
		if (n > 0) dfs(1, 0, n - 1, arr); // dfs cannot handle an empty range (L > R)
	}

	// Recomputes node o's aggregates from its children, then applies node o's own
	// tag (if any) on top, overriding whatever the children hold.
	void maintain(int o, int L, int R)
	{
		if (L < R)
		{
			sumv[o] = sumv[o * 2] + sumv[o * 2 + 1];
			maxv[o] = std::max(maxv[o * 2], maxv[o * 2 + 1]);
			minv[o] = std::min(minv[o * 2], minv[o * 2 + 1]);
		}
		if (setv[o] >= 0)
		{
			maxv[o] = minv[o] = setv[o];
			sumv[o] = setv[o] * (R - L + 1); // all R-L+1 elements equal setv[o]
		}
	}

	// Moves node o's tag down to both children and clears it. The children's aggregates
	// are refreshed by the maintain() calls update() makes afterwards.
	void pushdown(int o, int L, int R)
	{
		// Only a real tag (>= 0) may be pushed; pushing -1 would erase the children's tags.
		if (setv[o] >= 0)
		{
			setv[o * 2] = setv[o * 2 + 1] = setv[o];
			setv[o] = -1;
		}
	}

	// Assigns v to [y1, y2] inside node o's subtree. Only called on nodes intersecting [y1, y2].
	void update(int o, int L, int R)
	{
		if (y1 <= L && R <= y2)
		{
			setv[o] = v; // fully covered: the tag overrides every older tag below it
		}
		else
		{
			pushdown(o, L, R); // keep the older assignment on the part outside [y1, y2]
			int m = L + (R - L) / 2;
			// pushdown() changed both children's tags, so the child we do not recurse
			// into must still be maintained before node o recombines them.
			if (y1 <= m) update(o * 2, L, m); else maintain(o * 2, L, m);
			if (m < y2) update(o * 2 + 1, m + 1, R); else maintain(o * 2 + 1, m + 1, R);
		}
		maintain(o, L, R);
	}

	// Folds sum/min/max of [y1, y2] within node o into _sum/_min/_max.
	// Only called on nodes intersecting [y1, y2].
	void query(int o, int L, int R)
	{
		if (setv[o] >= 0)
		{
			// update() pushes tags down along every path it descends, so the topmost tag on a
			// path is the most recent assignment and anything below it is outdated: stop here.
			// Only the overlap of [L, R] with [y1, y2] contributes to the sum.
			_sum += setv[o] * (std::min(R, y2) - std::max(L, y1) + 1);
			_max = std::max(_max, setv[o]);
			_min = std::min(_min, setv[o]);
		}
		else if (y1 <= L && R <= y2)
		{
			_sum += sumv[o];
			_max = std::max(_max, maxv[o]);
			_min = std::min(_min, minv[o]);
		}
		else
		{
			int m = L + (R - L) / 2;
			if (y1 <= m) query(o * 2, L, m);
			if (y2 > m) query(o * 2 + 1, m + 1, R);
		}
	}

	// Sets every element of [a, b] to v. The range is clipped to [0, n-1]; an empty range
	// is a no-op. Negative v is ignored because negative values mean "no tag" in setv.
	void interval_set(int a, int b, int v)
	{
		a = std::max(a, 0);
		b = std::min(b, n - 1);
		if (a > b || v < 0) return;
		y1 = a; y2 = b; this->v = v;
		update(1, 0, n - 1);
	}

	// Stores the sum, min and max of [a, b] into _sum, _min and _max. The range is clipped
	// to [0, n-1]; for an empty range the results are _sum = 0, _min = INF, _max = -INF.
	void interval_query(int a, int b)
	{
		_sum = 0; _min = INF; _max = -INF;
		a = std::max(a, 0);
		b = std::min(b, n - 1);
		if (a > b) return;
		y1 = a; y2 = b;
		query(1, 0, n - 1);
	}
}t;
// Demo input: the first eight entries are 1..8, the remaining entries are zero.
int arr[maxn] = { 1,2,3,4,5,6,7,8 };

// Demo: build over the first 8 elements, print max/min/sum of [0,3],
// overwrite [0,3] with 1, then print max/min/sum of [0,3] again.
int main()
{
	t.build(arr, 8);
	for (int round = 0; round < 2; ++round)
	{
		if (round == 1)
			t.interval_set(0, 3, 1); // second round: every element of [0,3] becomes 1
		t.interval_query(0, 3);
		cout << t._max << " " << t._min << " " << t._sum << endl;
	}
	return 0;
}
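// 区间set：把区间内所有元素都设为v；同时维护区间元素和、最小值、最大值。
// 约定：setv[o] >= 0 才表示结点o被set过（-1表示没有），所以只能把区间设为非负值。
// set与add不同：set要考虑先后顺序（后一次set覆盖前一次）。因此update时需要pushdown；查询时遇到setv[o] >= 0就不再往下递归（下面可能还有更早的setv，但理应被这里的覆盖掉）。
// 注意update时多了2个maintain操作：pushdown修改了两个子结点的setv，没有递归进去的那个子结点也必须maintain，否则父结点合并时会用到它过期的值。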