前言
线段树是一种实用的数据结构,在各种算法题里经常用到。但是假如每次使用 都需要重新写一遍线段树代码,不仅费时费力,还容易出错。为了能够轻松使用线段树,本人综合考虑了各类线段树的特点, 用C++的模板功能编写了一套线段树的类模板,并给出了其使用方法和解题案例。希望能给各位读者一点帮助。
线段树简介
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
对于线段树中的每一个非叶子节点
[
a
,
b
]
[a,b]
[a,b] ,它的左儿子表示的区间为
[
a
,
a
+
b
2
]
\left[a,\frac{a+b}{2}\right]
[a,2a+b] ,右儿子表示的区间为
[
a
+
b
2
+
1
,
b
]
\left[\frac{a+b}{2}+1,b\right]
[2a+b+1,b]。因此线段树是平衡二叉树,最后的子节点数目为
N
N
N,即整个线段区间的长度。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为
O
(
log
N
)
O(\log N)
O(logN)。而未优化的空间复杂度为
2
N
2N
2N,因此有时需要离散化让空间压缩。
线段树原理
源码分享
文件:s_tree.h
#ifndef S_TREE_H
#define S_TREE_H
#include <cstring>
#define ROOT 1 //根节点下标
template <typename Node, typename Label = int>
class STree {
public:
STree(int start, int end, Node* arr,
Node (*merge)(const Node& lchild, const Node& rchild),
void (*modify)(Node& node, const Label& value) = NULL);
virtual ~STree();
void modify(int loc, const Node& new_data); //单点修改
void modify(int start, int end, const Label& value); //区间修改
const Node query(int start, int end); //区间查询
const Node query(int loc); //单点查询
private:
Node* tree; //节点数组
Label* slack; //懒标记
const int length, loc_min, loc_max;
Node (*f_merge)(const Node& lchild, const Node& rchild); //区间合并函数
void (*f_modify)(Node& node, const Label& value); //区间修改函数
void build_tree(int ind, int left, int right, Node* arr);
Node query_in(int ind, int left, int right, int start, int end);
void modify_in(int ind, int left, int right, int start, int end,
const Label& value);
void down(int ind, int left, int right); //懒标记下沉,同时更新信息
void update(int ind); //更新信息
};
template <typename Node, typename Label>
STree<Node, Label>::STree(int start, int end, Node* arr,
Node (*merge)(const Node& lchild, const Node& rchild),
void (*modify)(Node& node, const Label& value))
: length(4 * (end - start + 1)), loc_min(start), loc_max(end) {
if (start > end) {
throw "STree constructor error";
}
tree = new Node[length];
slack = new Label[length];
f_merge = merge;
f_modify = modify;
std::memset(slack, 0, length * sizeof(Label));
build_tree(ROOT, start, end, arr);
}
template <typename Node, typename Label>
STree<Node, Label>::~STree() {
delete[] tree;
delete[] slack;
}
template <typename Node, typename Label>
void STree<Node, Label>::modify(int loc, const Node& new_data) {
if (loc < loc_min || loc > loc_max) {
throw "STree modify error";
}
int ind = ROOT;
int left = loc_min;
int right = loc_max;
while (left < right) {//向下寻找节点
if (slack[ind]) {
down(ind, left, right);
}
int mid = (left + right) >> 1;
if (loc <= mid) {
right = mid;
ind = ind << 1;
} else {
left = mid + 1;
ind = ind << 1 | 1;
}
}
tree[ind] = new_data;//修改节点
int ind_p = ind >> 1;
while (ind_p >= ROOT) {//向上更新
Node temp = tree[ind_p];
update(ind_p);
if (temp == tree[ind_p]) {
break;
}
ind_p = ind_p >> 1;
}
}
template <typename Node, typename Label>
void STree<Node, Label>::modify(int start, int end, const Label& value) {
if (start > end || start < loc_min || end > loc_max) {
throw "STree modify error";
}
modify_in(ROOT, loc_min, loc_max, start, end, value);//递归区间修改
}
template <typename Node, typename Label>
const Node STree<Node, Label>::query(int loc) {
return query_in(ROOT, loc_min, loc_max, loc, loc);//递归区间查询
}
template <typename Node, typename Label>
const Node STree<Node, Label>::query(int start, int end) {
if (start > end || start < loc_min || end > loc_max) {
throw "STree query error";
}
return query_in(ROOT, loc_min, loc_max, start, end);//递归区间查询
}
template <typename Node, typename Label>
void STree<Node, Label>::build_tree(int ind, int left, int right, Node* arr) {
if (left == right) {
tree[ind] = arr[left];
return;
}
int mid = (left + right) >> 1;
build_tree(ind << 1, left, mid, arr);
build_tree(ind << 1 | 1, mid + 1, right, arr);
update(ind);
}
template <typename Node, typename Label>
Node STree<Node, Label>::query_in(int ind, int left, int right, int start,
int end) {
if (start <= left && right <= end) {
return tree[ind];
}
if (slack[ind] && left != right){
down(ind, left, right);
}
//判断是否需要两边查询
int mid = (left + right) >> 1;
if (mid >= end) {
return query_in(ind << 1, left, mid, start, end);
} else if (mid < start) {
return query_in(ind << 1 | 1, mid + 1, right, start, end);
} else {
return f_merge(query_in(ind << 1, left, mid, start, end),
query_in(ind << 1 | 1, mid + 1, right, start, end));
}
}
template <typename Node, typename Label>
void STree<Node, Label>::modify_in(int ind, int left, int right, int start,
int end, const Label& value) {
if (value == 0 || right < start || left > end) {
return;
}
if (start <= left && right <= end) { //本区间包含在内
f_modify(tree[ind], value);
slack[ind] += value;
return;
}
if (slack[ind] && left != right) down(ind, left, right); //清除懒标记
int mid = (left + right) >> 1;
modify_in(ind << 1, left, mid, start, end, value);
modify_in(ind << 1 | 1, mid + 1, right, start, end, value);
update(ind);
return;
}
template <typename Node, typename Label>
void STree<Node, Label>::down(int ind, int left, int right) {
//需要调用者负责检查 是否应该调用该函数
int mid = (left + right) >> 1;
if (slack[ind << 1] && left != mid) {
down(ind << 1, left, mid);
}
if (slack[ind << 1 | 1] && mid + 1 != right) {
down(ind << 1 | 1, mid + 1, right);
}
f_modify(tree[ind << 1], slack[ind]); //标记下沉
f_modify(tree[ind << 1 | 1], slack[ind]); //标记下沉
update(ind);
//懒标记更新
slack[ind << 1] = slack[ind];
slack[ind << 1 | 1] = slack[ind];
slack[ind] = 0;
}
template <typename Node, typename Label>
void STree<Node, Label>::update(int ind) {
tree[ind] = f_merge(tree[ind << 1], tree[ind << 1 | 1]);
}
#undef ROOT
#endif
代码分析
使用方法
- 明确线段树的类型,涉及线段树哪些操作
- 明确数据类型 Node 需要维护的信息
- 明确区间合并时如何维护 Node 信息,设计区间合并函数
- 如果涉及区间修改操作,需要明确 Label 的数据类型,以及区间修改函数
- 构造 Node 数组,明确需要构造线段树的区间的范围
- 使用 Node 数组指针、区间两端点下标、合并函数指针、修改函数指针 构造线段树对象
- 使用成员函数进行 区间查询、单点修改、区间修改、单点查询 等操作
应用案例
1. 区间最大值
1.1题目
给定一个 n 位数组和两种操作: 操作1:修改数组中某个位置的值 操作2:查询数组中某个区间的最大值
输入
第一行输入两个整数
n
,
m
n,m
n,m
(
1
≤
n
≤
10000
,
3
≤
m
≤
100000
)
(1\le n\le 10000,3\le m\le 100000)
(1≤n≤10000,3≤m≤100000),分别代表数组大小和操作数。
第二行包含
n
n
n 个整数,代表数组中相应的数字,数字大小不会超过 int 表示范围。
接下来
m
m
m 行,每行三个整数
a
,
b
,
c
(
a
∈
[
1
,
2
]
)
a,b,c (a\in [1,2])
a,b,c(a∈[1,2])
当
a
=
1
a=1
a=1 时,代表将数组
b
b
b 位置上的值修改成
c
,
(
1
≤
b
≤
n
)
c,(1\le b\le n)
c,(1≤b≤n) c is int 。
当
a
=
2
a=2
a=2 时,代表询问
[
b
,
c
]
[b,c]
[b,c] 区间内的最大值
(
1
≤
b
,
c
≤
n
)
(1\le b,c\le n)
(1≤b,c≤n),当
b
>
c
b>c
b>c 时,输出:
−
2147483648
-2147483648
−2147483648
输出
对于每个
a
=
2
a=2
a=2 的操作,输出查询区间内的最大值。
样例输入
6 5
6 9 4 3 2 1
2 2 5
1 2 3
2 1 6
1 5 12
2 1 6
样例输出
9
6
12
1.2思路
1.3代码
#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;
int arr[10005];
int func(const int& l, const int& r) { return max(l, r); }
int main()
{
int n, m;
cin >> n >> m;
for (int i = 1; i <= n; i++){
cin >> arr[i];
}
STree<int> stree(1, n, arr, func);
for (int i = 0; i < m; i++){
int a, b, c;
cin >> a >> b >> c;
if (a == 2){
//查询
if (b > c){
cout << -2147483648 << endl;
}
else {
cout << stree.query(b, c) << endl;
}
}
else {
//修改
stree.modify(b, c);
arr[b] = c;
}
}
return 0;
}
2. 区间和值
2.1题目
给定一个
n
n
n 位数组和两种操作:
操作1:数组中某个区间的所有数字加上一个值
操作2:查询数组中某个区间的所有数字之和
输入
第一行输入两个整数
n
,
m
(
1
≤
n
≤
10000
,
3
≤
m
≤
100000
)
n,m (1\le n\le 10000,3\le m\le 100000)
n,m(1≤n≤10000,3≤m≤100000),分别代表数组大小和操作数。
第二行包含
n
n
n 个整数,代表数组中相应的数字,数字大小不会超过 int 表示范围。
接下来
m
m
m行,每行三个或四个整数
a
,
b
,
c
,
d
(
a
∈
[
1
,
2
]
)
a,b,c,d(a\in [1,2])
a,b,c,d(a∈[1,2])
当
a
=
1
a=1
a=1 时,接下来输入
b
,
c
,
d
b,c,d
b,c,d,代表将数组
[
b
,
c
]
[b,c]
[b,c]区间内的数字加上
d
,
(
1
≤
b
,
c
≤
n
)
d,(1\le b,c\le n)
d,(1≤b,c≤n)
d
d
d is int
当
a
=
2
a=2
a=2 时,接下来输入
b
,
c
b,c
b,c,输入代表询问
[
b
,
c
]
[b,c]
[b,c] 区间内的和值
(
1
≤
b
,
c
≤
n
)
(1\le b,c\le n)
(1≤b,c≤n),当
b
>
c
b>c
b>c 时,输出
0
0
0。
输出
对于每个
a
=
2
a=2
a=2 的操作,输出查询区间内数字的和值,答案不会超过64位整型(long long)的表示范围。
样例输入
6 5
6 9 4 3 2 1
2 2 5
1 2 3 5
2 1 6
1 5 12 3
2 1 6
样例输出
18
35
41
2.2思路
2.3代码
#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;
struct Node{
int length;
long long sum;
};
Node value[10005];
Node sum(const Node& x1, const Node& x2){
Node ans;
ans.sum = x1.sum + x2.sum;
ans.length = x1.length + x2.length;
return ans;
}
void modify(Node& node,const int& value){
node.sum += value * node.length;
}
bool operator==(const Node& n1, const Node& n2){
return (n1.sum == n2.sum && n1.length == n2.length);
}
int main()
{
int n, m;
scanf("%d%d", &n, &m);
for (int i = 1; i <= n; i++){
scanf("%lld",&value[i].sum);
value[i].length = 1;
}
STree<Node> stree(1, n + 5, value, sum, modify);
for (int i = 0; i < m; i++){
int a;
scanf("%d", &a);
if (a == 2){
//查询
int b, c;
scanf("%d%d", &b, &c);
if (b > c){
printf("0\n");
}
else {
printf("%lld\n", stree.query(b, c).sum);
}
}
else {
//修改
int b, c;
int d;
scanf("%d%d%d", &b, &c, &d);
if (b > c){
continue;
}
c = min(c, n);
stree.modify(b, c, d);
}
}
return 0;
}
3. 区间最大子段和
3.1题目
给定长度为
N
N
N 的数列
A
A
A,以及
M
M
M 条指令
(
N
≤
500000
,
M
≤
100000
)
(N\le 500000, M\le 100000)
(N≤500000,M≤100000),每条指令可能是以下两种之一:
“2 x y”,把
A
[
x
]
A[x]
A[x] 改成
y
y
y。
“1 x y”,查询区间
[
x
,
y
]
[x,y]
[x,y] 中的最大连续子段和,即
m
a
x
(
∑
i
=
l
r
A
[
i
]
)
(
x
≤
l
≤
r
≤
y
)
max(\sum_{i = l}^r A[i]) _{(x\le l\le r\le y)}
max(∑i=lrA[i])(x≤l≤r≤y)。
对于每个询问,输出一个整数表示答案。
输入
第一行两个整数
N
(
1
≤
N
≤
500000
)
,
M
(
1
≤
M
≤
100000
)
N(1≤N≤500000),M(1≤M≤100000)
N(1≤N≤500000),M(1≤M≤100000)
第二行
N
N
N 个整数
A
i
(
1
≤
∣
A
i
∣
≤
1000
)
Ai(1≤|Ai|≤1000)
Ai(1≤∣Ai∣≤1000)
接下来
M
M
M 行每行
3
3
3 个整数
k
,
x
,
y
,
k
=
1
k,x,y,k=1
k,x,y,k=1 表示查询(此时如果
x
>
y
x>y
x>y,请交换
x
,
y
x,y
x,y),
k
=
2
k=2
k=2表示修改
输出
对于每个询问输出一个整数表示答案。
输入样例1
5 3
1 2 -3 4 5
1 2 3
2 2 -1
1 3 2
输出样例1
2
-1
3.2思路
3.3代码
#include "s_tree.h"
#include <iostream>
#include <cstdio>
using namespace std;
struct Node {
long long sum, in_max, l_max, r_max;
};
Node arr[500005];
Node func(const Node& l, const Node& r) {
Node ans;
ans.in_max = max(l.in_max, r.in_max);
ans.in_max = max(ans.in_max, l.r_max + r.l_max);
ans.sum = l.sum + r.sum;
ans.l_max = max(l.l_max, l.sum + r.l_max);
ans.r_max = max(r.r_max, r.sum + l.r_max);
return ans;
}
bool operator==(const Node& x1, const Node& x2){
return (x1.sum == x2.sum && x1.in_max == x2.in_max &&
x1.r_max == x2.r_max && x1.l_max == x2.l_max);
}
int main()
{
int n, m;
scanf("%d%d",&n, &m);
for (int i = 1; i <= n; i++){
cin >> arr[i].sum;
arr[i].in_max = arr[i].sum;
arr[i].l_max = arr[i].in_max;
arr[i].r_max = arr[i].in_max;
}
STree<Node> stree(1, n, arr, func);
for (int i = 0; i < m; i++){
int a, b, c;
scanf("%d%d%d",&a, &b, &c);
if (a == 1){
//查询
if (b > c){
swap(b, c);
}
printf("%lld\n", stree.query(b, c).in_max);
}
else {
//修改
arr[b].sum = c;
arr[b].in_max = c;
arr[b].l_max = c;
arr[b].r_max = c;
stree.modify(b, arr[b]);
}
}
return 0;
}