线段树
线段树是什么?
线段树是一颗二叉搜索树,对于一个线段,用二叉树的形式表示。它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
线段树的作用
用来解决区间问题
最基本的用来单点修改,区间修改,单点查询,区间查询。
其中两个经典问题是维护某段长度的最大值最小值,某段区间和。
然后可以引申出求区间最大值和最小值之差。
线段树的操作
线段树有五个操作,pushup,pushdown,build,query,modify
pushup 操作
这个操作的意思就是向上维护,如果修改了某一点或者某一区间的值,这时候必然要将它的父节点更新。
pushdown 操作
这个操作是加了懒标记以后出现的操作,一般加在查询和修改操作中,这个操作的目的是将懒标记传到子节点上,如果没有这个操作, 会导致pushup操作出现错误。
build 操作
建树操作。
query 操作 和 modify 操作
查询和修改操作。
[L, R]是要查询的区间,[Tl, Tr]是树的一个结点范围。
(1)如果[Tl, Tr]在[L, R]范围内,即L < Tl < Tr < R,此时直接返回区间值。
(2)如果[Tl, Tr]和[L, R]交集不为空
①如果Tl < L < Tr < R
mid = Tl + Tr >> 1,如果L > mid 只递归右边,如果L < mid 同时递归左边和右边。
②L < Tl < R < Tr,和第一种完全相反。
③如果[L, R]在[Tl, Tr]范围内
mid = Tl + Tr >> 1,如果R <= mid,只递归左边。如果L > mid,只递归右边。 其他情况(L < mid < R)同时递归左边递归右边。
线段树开4 * n空间解释
时间复杂度
为什么要加懒标记?
如果每次插入操作是在一条线段上每个位置均加k,而查询操作是计算一条线段上的总和,那么在结点上需要记录的值为sum。
这里会遇到一个问题:为了使所有sum值都保持正确,每一次插入操作可能要更新O(N)个sum值,从而使时间复杂度退化为O(N)。
解决方案是Lazy思想:对整个结点进行的操作,先在结点上做标记,而并非真正执行,直到根据查询操作的需要分成两部分。
关于懒标记:
懒标记解释
以处理区间和为例
不带懒标记
#include <algorithm>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <stack>
#include <vector>
//#include <tr1/unordered_map>
#include <bits/stdc++.h>
#include <cmath>
#include <unordered_map>
using namespace std;
#define sfd(i) scanf("%d", &i)
#define sfl(i) scanf("%I64d", &i)
#define sfs(i) scanf("%s", (i))
#define prd(i) printf("%d\n", i)
#define prl(i) printf("%I64d\n", i)
#define sff(i) scanf("%lf", &i)
#define ll long long
#define ull unsigned long long
#define uint unsigned int
#define mst(x, y) memset(x, y, sizeof(x))
#define INF 0x3f3f3f3f
#define inf 8e19
#define eps 1e-10
const int maxn = 3e5;
#define PI acos(-1.0)
#define lowbit(x) ((x) & (-x))
#define fl() printf("flag\n")
#define MOD(x) ((x % mod) + mod) % mod
#define endl '\n'
#define pb push_back
#define lson (rt << 1)
#define rson (rt << 1 | 1)
#define FAST_IO \
ios::sync_with_stdio(false); \
cin.tie(0); \
cout.tie(0)
const int N = 50010;
struct node
{
int l, r;
int sum;
}tr[N * 4];
int w[N * 4];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r, w[r]};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int pos, int w)
{
if(tr[u].l == tr[u].r) tr[u].sum += w;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if(pos <= mid) modify(u << 1, pos, w);
if(pos > mid) modify(u << 1 | 1, pos, w);
pushup(u);
}
}
int query(int u, int l, int r)
{
int res = 0;
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) res = query(u << 1, l, r);
if(r > mid) res += query(u << 1 | 1, l, r);
return res;
}
int main() {
int t, n, a, b;
sfd(t);
char s[10];
int Case = 0;
while (t--) {
sfd(n);
for (int i = 1; i <= n; i++) sfd(w[i]);
build(1, 1, n);
printf("Case %d:\n", ++Case);
while (sfs(s) && strcmp(s, "End")) {
if (strcmp(s, "Query") == 0) {
sfd(a), sfd(b);
int c = query(1, a, b);
cout << c << endl;
} else if (strcmp(s, "Add") == 0) {
sfd(a), sfd(b);
modify(1, a, b);
} else if (strcmp(s, "Sub") == 0) {
sfd(a), sfd(b);
modify(1, a, -b);
}
}
}
return 0;
}
带懒标记
#include <bits/stdc++.h>
using namespace std;
const int N = 1e6 + 10;
int n, m;
int w[N];
typedef long long ll;
struct node
{
int l, r;
int sum, add;
}tr[N * 4];
void pushup(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void pushdown(int u)
{
if(tr[u].add)
{
tr[u << 1].sum += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].add;
tr[u << 1 | 1].sum += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].add;
tr[u << 1].add += tr[u].add;
tr[u << 1 | 1].add += tr[u].add;
tr[u].add = 0;
}
}
void build(int u, int l, int r)
{
if(l == r) tr[u] = {l, r, w[r], 0};
else
{
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
pushup(u);
}
}
void modify(int u, int l, int r, int d)
{
if(tr[u].l >= l && tr[u].r <= r)
{
tr[u].add += d;
tr[u].sum += d * (tr[u].r - tr[u].l + 1);
}
else
{
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) modify(u << 1, l, r, d);
if(r > mid) modify(u << 1 | 1, l, r, d);
pushup(u);
}
}
int query(int u, int l, int r)
{
if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int sum = 0;
int mid = tr[u].l + tr[u].r >> 1;
if(l <= mid) sum = query(u << 1, l, r);
if(r > mid) sum += query(u << 1 | 1, l, r);
return sum;
}
int main()
{
scanf("%d%d", &n, &m);
for(int i = 1; i <= n; i++) cin >> w[i];
build(1, 1, n);
for(int i = 0; i < m; i++)
{
char op[2];
int a, b, c;
cin >> op;
if(op[0] == 'Q')
{
cin >> a >> b;
cout << query(1, a, b) << endl;
}
else if(op[0] == 'C')
{
cin >> a >> b >> c;
modify(1, a, b, c);
}
}
return 0;
}