线段树(Segment Tree) 是一个优雅的数据结构。
他能用O(logn)的时间复杂度实现 单点修改 区间查询 区间修改 ,但这些操作并非线段树的操作上限,通过维护不同的信息,线段树能在许多地方派上用场 ,是算法竞赛中常用的一种数据结构。
线段树的引入 (P3372 【模板】线段树 1)
题目描述
如题,已知一个数列,你需要进行下面两种操作:
1.将某区间每一个数加上x
2.求出某区间每一个数的和
输入格式
第一行包含两个整数N、M,分别表示该数列数字的个数和操作的总个数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数加上k
操作2: 格式:2 x y 含义:输出区间[x,y]内每个数的和
输出格式
输出包含若干行整数,即为所有操作2的结果。
1.线段树的每一个节点是一个结构体,维护了一个区间内的各种信息。
2.线段树是平衡二叉树,每个节点u的左右子节点编号分别为 u<<1 和u<<1|1;
3.令mid = l+r >>1 ;两个子节点分别维护 [l,mid] 和 [mid+1,r] 的区间信息。
在本题中我们维护五个信息
struct Node
{
ll l, r; //区间的左端点和右端点
ll sum;//区间和
ll mk;//懒标记,在区间修改中发挥作用
int len;//区间长度
} tr[4 * maxn];
build:建立线段树
void push_up(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, w[r],0,1};
else
{
tr[u] = {l, r, 0, 0, r - l + 1};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
push_up(u);
}
}
线段树的建立是一个递归的过程,从根节点向下遍历到叶子节点,然后向上更新每个节点的区间信息( push_up(int u) )
modify:单点修改
void modify(int u, int x, int v)
{
if (tr[u].l == tr[u].r)
tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid)
modify(u << 1, x, v);
else
modify(u << 1 | 1, x, v);
push_up(u);
}
}
算法思想如图,略过不表。
r_modify:*区间修改
下面的内容是线段树的精髓所在,先引入懒标记的概念,然后我们将看到通过懒标记我们能把线段树玩出花来。
对于区间修改,朴素的想法是用递归的方式一层层修改(类似于线段树的建立),但这样的时间复杂度比较高。使用懒标记后,对于那些正好是线段树节点的区间,我们不继续递归下去,而是打上一个标记,将来要用到它的子区间的时候,再向下传递。
回到第一张图,考虑如果我们在[3,5]区间内+v 所进行的操作。
第一种方法是,用递归的方式修改到叶子节点。
但是注意到遍历到节点3的时候 区间[4,5]已经被完全包含,并且我们知道区间[4,5]的长度len,因此我们可以直接对节点3所维护的区间和+ len*v 而不继续往下遍历。
在这个过程中,我们会标记未向下传递的v值,可能存在这样的多次操作,v值不断累积,直到我们必须用到他的子区间的时候,我们再取出v值,向下传递。
void push_down(int u, int len)
{
//标记向下传递。
tr[u << 1].mk += tr[u].mk;
tr[u << 1 | 1].mk += tr[u].mk;
//更新左右子节点
tr[u << 1].sum += tr[u].mk * (len - len / 2);
tr[u << 1 | 1].sum += tr[u].mk * (len / 2);
///清除标记
tr[u].mk = 0;
}
void r_modify(int u, int l, int r, ll v)
{
if (tr[u].l > r || tr[u].r < l) //区间若无交集直接退出
return;
else if (tr[u].l >= l && tr[u].r <= r)//区间被完全包含
{
//只更新当前节点的值,不继续向下遍历。
tr[u].sum += tr[u].len * v;
//如果当前节点非叶子节点,打上懒标记。
if (tr[u].r > tr[u].l)
tr[u].mk += v;
}
//若区间有部分交集。
else
{
//此时当前区间并不完全被目标区间所包含。
//因此必须向下遍历他的子节点
//更新子节点的值
//每次标记向下传递一层(想一想,为什么)
push_down(u, tr[u].len);
//递归地向左右子节点查找。
r_modify(u << 1, l, r, v);
r_modify(u << 1 | 1, l, r, v);
//记得更新。
push_up(u);
}
}
区间查询
和区间修改的方法类似,略过不表。
ll query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].sum;
ll mid = tr[u].l + tr[u].r >> 1;
push_down(u, tr[u].len);
ll sum = 0;
if (l <= mid)
sum += query(u << 1, l, r);
if (r > mid)
sum += query(u << 1 | 1, l, r);
return sum;
}
模板题AC代码
#include <bits/stdc++.h>
using namespace std;
const int maxn = 1e6 + 10;
typedef long long ll;
int n, m;
int w[maxn], mk[maxn];
struct Node
{
ll l, r;
ll sum;
ll mk;
int len;
} tr[4 * maxn];
//更新
void push_up(int u)
{
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void push_down(int u, int len)
{
tr[u << 1].mk += tr[u].mk;
tr[u << 1 | 1].mk += tr[u].mk;
tr[u << 1].sum += tr[u].mk * (len - len / 2);
tr[u << 1 | 1].sum += tr[u].mk * (len / 2);
tr[u].mk = 0;
}
//区间查询
ll query(int u, int l, int r)
{
if (tr[u].l >= l && tr[u].r <= r)
return tr[u].sum;
ll mid = tr[u].l + tr[u].r >> 1;
push_down(u, tr[u].len);
ll sum = 0;
if (l <= mid)
sum += query(u << 1, l, r);
if (r > mid)
sum += query(u << 1 | 1, l, r);
return sum;
}
//区间修改
void r_modify(int u, int l, int r, ll v)
{
if (tr[u].l > r || tr[u].r < l)
return;
else if (tr[u].l >= l && tr[u].r <= r)
{
tr[u].sum += tr[u].len * v;
if (tr[u].r > tr[u].l)
tr[u].mk += v;
}
else
{
push_down(u, tr[u].len);
r_modify(u << 1, l, r, v);
r_modify(u << 1 | 1, l, r, v);
push_up(u);
}
}
//单点修改
void modify(int u, int x, int v)
{
if (tr[u].l == tr[u].r)
tr[u].sum += v;
else
{
int mid = tr[u].l + tr[u].r >> 1;
if (x <= mid)
modify(u << 1, x, v);
else
modify(u << 1 | 1, x, v);
push_up(u);
}
}
void build(int u, int l, int r)
{
if (l == r)
tr[u] = {l, r, w[r],0,1};
else
{
tr[u] = {l, r, 0, 0, r - l + 1};
int mid = l + r >> 1;
build(u << 1, l, mid);
build(u << 1 | 1, mid + 1, r);
push_up(u);
}
}
int main()
{
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++)
cin >> w[i];
build(1, 1, n);
while (m--)
{
int op;
cin >> op;
if (op == 1)
{
int l, r, v;
cin >> l >> r >> v;
r_modify(1, l, r, v);
}
else
{
int l, r;
cin >> l >> r;
cout << query(1, l, r) << endl;
}
}
system("pause");
return 0;
}