上次写线段树,已经是很早之前了,时间久了容易忘,特此重写一次复习一遍
// 线段树模板
#include <bits/stdc++.h>
#pragma warning (disable:4996)
#pragma warning (disable:6031)
#define mem(a, b) memset(a, b, sizeof a)
#define debug puts("--------------------------------")
using namespace std;
struct p {
int l, r, val, max, min;
int mark;
};
const int N = 310;
p t[N * 4];// 线段树需要四倍空间
void build(int l, int r, int k) {
t[k].l = l;
t[k].r = r;
if (l == r) {
//scanf("%d", &t[k].val);
t[k].val = 1;
t[k].max = t[k].min = t[k].val;
return;
}
int mid = (l + r) / 2;
build(l, mid, k * 2);
build(mid + 1, r, k * 2 + 1);
t[k].max = max(t[k * 2].max, t[k * 2 + 1].max);
t[k].min = min(t[k * 2].min, t[k * 2 + 1].min);
t[k].val = t[k * 2 + 1].val + t[k * 2].val;
}
void down(int k) {
if (t[k].mark) {
int temp = t[k].mark;
t[k * 2].max += temp;
t[k * 2].min += temp;
t[k * 2].val += temp * (t[k * 2].r - t[k * 2].l + 1);
t[k * 2 + 1].max += temp;
t[k * 2 + 1].min += temp;
t[k * 2 + 1].val += temp * (t[k * 2 + 1].r - t[k * 2 + 1].l + 1);
t[k * 2].mark += t[k].mark;
t[k * 2 + 1].mark += t[k].mark;
t[k].mark = 0;
}
}
void update(int l, int r, int k, int val) {
int ll = t[k].l;
int rr = t[k].r;
if (ll >= l && rr <= r) {
t[k].max += val;
t[k].min += val;
t[k].val += val * (t[k].r - t[k].l + 1);
t[k].mark += val;
}
else if (ll > r || rr < l) {
return;
}
else {
down(k);
int mid = (ll + rr) / 2;
update(l, mid, k * 2, val);
update(mid + 1, r, k * 2 + 1, val);
t[k].max = max(t[k * 2].max, t[k * 2 + 1].max);
t[k].min = min(t[k * 2].min, t[k * 2 + 1].min);
t[k].val += (t[k].r - t[k].l + 1) * val;
}
}
int query(int x, int k) {
if (t[k].l == t[k].r) {
return t[k].val;
}
down(k);
int mid = (t[k].l + t[k].r) / 2;
if (x <= mid)return query(x, k * 2);
else return query(x, k * 2 + 1);
}
int query(int l, int r, int k) {
// 单点查询可以看作是区间的左右端点相同
// PS:[2, 2]
int ll = t[k].l;
int rr = t[k].r;
if (l <= ll && r >= rr) {
return t[k].val;
}
else if (r < ll || rr < l) {
return 0;
// 因为是以区间和为例子,所以返回0
// 这个视情况而定
// 如果查询的是区间最大值,就返回一个很小的值(不影响结果的值)
}
else {
down(k);
return query(l, r, k * 2) + query(l, r, k * 2 + 1);
}
}
void show() {
for (int i = 1; i <= 5; i++) {
printf("%d ", query(i, i, 1));
}
puts("");
debug;
}
int main()
{
mem(t, 0);
build(1, 5, 1);
show();
update(1, 4, 1, 1);
show();
update(2, 5, 1, -1);
show();
update(3, 3, 1, 2);
show();
return 0;
}