1. 线段树原理
基于分治思想的二叉树,叶子节点存储元素,非叶子节点存储统计值,如区间和,区间最值等。维护和查询树的时间复杂度为树的高度logn。
特点:
- 根从1开始时:左儿子=根x2,右儿子=根x2+1
- 左右节点相等时为叶子节点,存储元素值,其它节点存储统计值。
- 需要增加节点保存统计值,因此需要开4倍元素个数的空间(证明略)
- 以根节点进行区间分治,因此时间复杂度为树的高度logn。
2. 线段树应用
线段树是一种用于维护区间信息的高级数据结构,用于计算区间和,区间最值类问题。修改和查询可以在logn时间内完成。
一般应用场景有:
单点修改,区间查询
区间修改,区间查询
3. 线段树C++实现
1. 点修区查
#include<algorithm>
#include<iostream>
using namespace std;
#define ll long long
#define lc p<<1
#define rc p<<1|1
const int N = 200001;
ll a[N];
struct SegTree {
ll l, r, sum;
}tr[N<<2];
//建树
void Build(ll p, ll l, ll r) {
tr[p] = { l,r,a[l] };
if (l == r) return;
ll m = l + r >> 1;
Build(lc, l, m);
Build(rc, m + 1, r);
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
//单点修改
void change(ll p, ll x, ll v) {
if (tr[p].l == x && tr[p].r == x) {
tr[p].sum += v;
return;
}
ll m = tr[p].l + tr[p].r >> 1;
if (x <= m) change(lc, x, v);
else change(rc, x, v);
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
//区间查询
ll query(ll p, ll x, ll y) {
if (x <= tr[p].l && tr[p].r <= y) return tr[p].sum;
ll m = tr[p].l + tr[p].r >> 1;
ll sum = 0;
if (x <= m) sum += query(lc, x, y);
if (y > m) sum += query(rc, x, y);
return sum;
}
int main() {
int n, m, op, x, y, k;
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
Build(1, 1, n);
while (m--) {
cin >> op >> x >> y;
if (op == 1) {
cin >> k;
change(1, x, k);
}
else {
cout << query(1, x, y) << endl;
}
}
return 0;
}
2. 区修区查
区修时需要引入懒标记技术,否则时间复杂度很高。
#include<iostream>
#include<algorithm>
using namespace std;
#define N 200001
#define lc p<<1
#define rc p<<1|1
int a[N];
struct SegTree {
int l, r, sum, add;//懒标记
}tr[N<<2];
void pushup(int p) {
tr[p].sum = tr[lc].sum + tr[rc].sum;
}
//建树
void Build(int p, int l, int r) {
tr[p] = { l,r,a[l],0 };
if (l == r) return;
int m = l + r >> 1;
Build(lc, l, m);
Build(rc, m + 1, r);
pushup(p);
}
//下传懒标记
void pushdown(int p) {
int t = tr[p].add;
if (t) {
tr[lc].sum += t * (tr[lc].r - tr[lc].l + 1);
tr[rc].sum += t * (tr[rc].r - tr[rc].l + 1);
tr[lc].add += t;
tr[rc].add += t;
tr[p].add = 0;
}
}
//区修
void change(int p, int x, int y, int k) {
if (x <= tr[p].l && tr[p].r <= y) {
tr[p].sum += (tr[p].r - tr[p].l + 1)* k;
tr[p].add += k;
return;
}
int m = tr[p].l + tr[p].r >> 1;
pushdown(p);
if (x <= m) change(lc, x, y, k);
if (y > m) change(rc, x, y, k);
pushup(p);
}
//区查
int query(int p, int x, int y) {
if (x <= tr[p].l && tr[p].r <= y) {
return tr[p].sum;
}
int m = tr[p].l + tr[p].r >> 1;
pushdown(p);
int sum = 0;
if (x <= m) sum += query(lc, x, y);
if (y > m) sum += query(rc, x, y);
return sum;
}
int main() {
int n, m, op, x, y, k;
cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
Build(1, 1, n);
while (m--) {
cin >> op >> x >> y;
if (op == 1) {
cin >> k;
change(1, x, y, k);
}
else {
cout << query(1, x, y) << endl;
}
}
return 0;
}