上图 from 熊掌搜索
- 类似数据结构:树状数组
1. 概念
线段树是一种二叉树,它的每个结点表示原数组的一个区间:
- 常常用来查询区间的:和、最小值、最大值
- 树结点中存放的不是普通二叉树那样的单个值,而是该区间的汇总信息(区间和、最大值、最小值)以及区间端点,其结点结构如下:
// Node of a segment tree: caches the sum, maximum and minimum of the elements
// whose indices lie in the closed range [start, end].
class TreeNode
{
public:
    int sum;                // sum of the elements in [start, end]
    int MAX;                // largest element in [start, end]
    int MIN;                // smallest element in [start, end]
    int start, end;         // inclusive bounds of the covered index range
    TreeNode *left, *right; // child nodes (nullptr until attached)
    // Creates a node for [s, e] with sum, MAX and MIN all set to v.
    // Members are always initialised in declaration order, so the initialiser
    // list is written in that same order (avoids -Wreorder surprises).
    TreeNode(int s, int e, int v)
        : sum(v), MAX(v), MIN(v), start(s), end(e), left(nullptr), right(nullptr)
    {
    }
};
2. 建树
- 传入数组,及其左右极限端点
- 递归地把区间对半划分(自顶向下),再由左右子结点的信息合并出父结点的信息(自底向上)
// Recursively builds the segment tree for A[L..R] (inclusive) and returns its
// root, or a null pointer when the range is empty (L > R).
// The range is halved on the way down; on the way back up each internal node
// merges the sum / MAX / MIN of its left child first, then its right child.
TreeNode* build(vector<int>& A, int L, int R)
{
    if (L > R)
        return nullptr; // empty range: no node
    TreeNode* node = new TreeNode(L, R, A[L]);
    if (L != R)
    {
        const int half = L + ((R - L) >> 1); // midpoint without computing L + R
        node->left = build(A, L, half);
        node->right = build(A, half + 1, R);
        node->sum = 0;
        TreeNode* kids[] = {node->left, node->right};
        for (TreeNode* kid : kids)
        {
            if (kid == nullptr)
                continue;
            node->sum += kid->sum;
            node->MAX = max(node->MAX, kid->MAX);
            node->MIN = min(node->MIN, kid->MIN);
        }
    }
    return node;
}
3. 查询
- 时间复杂度:$O(\log n)$
// Returns {sum, min, max} over the elements whose indices lie both in [s, e]
// and in the range covered by rt. A subtree that does not overlap [s, e] -- or a
// null rt, as produced for an empty array -- contributes the neutral triple
// {0, INT_MAX, INT_MIN}.
vector<int> query(TreeNode *rt, int s, int e)
{
    if (rt == nullptr || s > rt->end || e < rt->start)
        return {0, INT_MAX, INT_MIN}; // no overlap: identity for sum / min / max
    if (s <= rt->start && rt->end <= e)
        return {rt->sum, rt->MIN, rt->MAX}; // node fully inside [s, e]: use cached values
    // Partial overlap: combine the answers from both children.
    vector<int> l = query(rt->left, s, e);
    vector<int> r = query(rt->right, s, e);
    return {l[0] + r[0], min(l[1], r[1]), max(l[2], r[2])};
}
4. 修改(单点修改;以下代码是完整程序中 SegmentTree 的成员函数,data 为其保存的原数组副本)
- 时间复杂度:$O(\log n)$
// Point update: sets element id to val. Updates the matching leaf and the
// backing array `data` (a member of SegmentTree in the full program), then
// refreshes sum / MAX / MIN of every node on the path back up to rt.
// Does nothing when rt is null or id lies outside rt's range.
void modify(TreeNode *rt, int id, int val)
{
    if (rt == nullptr || id < rt->start || id > rt->end)
        return; // empty tree or index out of range: nothing to update
    if (rt->start == rt->end)
    { // leaf for index id: its aggregates are just the new value
        rt->sum = val;
        rt->MAX = val;
        rt->MIN = val;
        data[id] = val;
        return;
    }
    int mid = rt->start + ((rt->end - rt->start) >> 1); // same split as build
    if (id > mid)
        modify(rt->right, id, val);
    else
        modify(rt->left, id, val);
    // Recompute this node (rt, not the global root) from scratch. Starting from
    // the old MAX/MIN would be wrong: they may come from the element just changed.
    rt->sum = 0;
    rt->MAX = INT_MIN;
    rt->MIN = INT_MAX;
    TreeNode* children[] = {rt->left, rt->right};
    for (TreeNode* child : children)
    {
        if (child == nullptr)
            continue;
        rt->sum += child->sum;
        rt->MAX = max(rt->MAX, child->MAX);
        rt->MIN = min(rt->MIN, child->MIN);
    }
}
5. 完整代码及测试
/**
* @description: 线段树
* @author: michael ming
* @date: 2020/3/13 0:21
* @modified by:
* @Website: https://michael.blog.csdn.net/
*/
#include<algorithm>
#include<climits>
#include<iostream>
#include<vector>
using namespace std;
// Node of a segment tree: caches the sum, maximum and minimum of the elements
// whose indices lie in the closed range [start, end].
class TreeNode
{
public:
    int sum;                // sum of the elements in [start, end]
    int MAX;                // largest element in [start, end]
    int MIN;                // smallest element in [start, end]
    int start, end;         // inclusive bounds of the covered index range
    TreeNode *left, *right; // child nodes (nullptr until attached)
    // Creates a node for [s, e] with sum, MAX and MIN all set to v.
    // Members are always initialised in declaration order, so the initialiser
    // list is written in that same order.
    TreeNode(int s, int e, int v)
        : sum(v), MAX(v), MIN(v), start(s), end(e), left(nullptr), right(nullptr)
    {
    }
};

// Segment tree over an int array supporting range {sum, min, max} queries and
// point updates. Owns its nodes (allocated with new in build(), freed in the
// destructor) and keeps a copy of the array in `data`.
class SegmentTree
{
public:
    TreeNode* root;   // nullptr when constructed from an empty array
    vector<int> data; // copy of the input array, kept in sync by modify()

    SegmentTree(vector<int>& A)
    {
        // Convert before subtracting: for an empty A, A.size() - 1 would wrap
        // around as an unsigned value instead of producing -1.
        root = build(A, 0, static_cast<int>(A.size()) - 1);
        data = A;
    }
    ~SegmentTree()
    {
        destroy(root);
    }
    // Nodes are owned through raw pointers; a member-wise copy would make two
    // trees delete the same nodes, so copying is disabled.
    SegmentTree(const SegmentTree&) = delete;
    SegmentTree& operator=(const SegmentTree&) = delete;

    // Frees the subtree rooted at rt (children first). Accepts nullptr.
    void destroy(TreeNode* rt)
    {
        if (!rt) return;
        destroy(rt->left);
        destroy(rt->right);
        delete rt;
    }

    // Builds the subtree for A[L..R] (inclusive) and returns its root, or
    // nullptr when L > R. Ranges are halved on the way down; each internal node
    // is filled in from its children on the way back up.
    TreeNode* build(vector<int>& A, int L, int R)
    {
        if (L > R)
            return nullptr;
        TreeNode* rt = new TreeNode(L, R, A[L]);
        if (L == R)
            return rt; // leaf: already holds A[L]
        int mid = L + ((R - L) >> 1); // midpoint without computing L + R
        rt->left = build(A, L, mid);
        rt->right = build(A, mid + 1, R);
        pull(rt);
        return rt;
    }

    // Returns {sum, min, max} over the elements whose indices lie both in
    // [s, e] and in rt's range. Non-overlapping parts (and a null rt, i.e. an
    // empty tree) contribute the neutral triple {0, INT_MAX, INT_MIN}.
    vector<int> query(TreeNode *rt, int s, int e)
    {
        if (rt == nullptr || s > rt->end || e < rt->start)
            return {0, INT_MAX, INT_MIN}; // no overlap
        if (s <= rt->start && rt->end <= e)
            return {rt->sum, rt->MIN, rt->MAX}; // fully inside: cached values
        // Partial overlap: combine the answers from both children.
        vector<int> l = query(rt->left, s, e);
        vector<int> r = query(rt->right, s, e);
        return {l[0] + r[0], min(l[1], r[1]), max(l[2], r[2])};
    }

    // Point update: sets element id to val in both the tree and `data`.
    // Does nothing when rt is null or id lies outside rt's range.
    void modify(TreeNode *rt, int id, int val)
    {
        if (rt == nullptr || id < rt->start || id > rt->end)
            return; // empty tree or index out of range: nothing to update
        if (rt->start == rt->end)
        { // leaf for index id
            rt->sum = val;
            rt->MAX = val;
            rt->MIN = val;
            data[id] = val;
            return;
        }
        int mid = rt->start + ((rt->end - rt->start) >> 1); // same split as build()
        if (id > mid)
            modify(rt->right, id, val);
        else
            modify(rt->left, id, val);
        // Recompute this node (not the global root) from its children; the old
        // MAX/MIN may have come from the element that was just overwritten.
        pull(rt);
    }

private:
    // Recomputes rt's sum / MAX / MIN from scratch using its children.
    static void pull(TreeNode* rt)
    {
        rt->sum = 0;
        rt->MAX = INT_MIN;
        rt->MIN = INT_MAX;
        TreeNode* const children[] = {rt->left, rt->right};
        for (TreeNode* child : children)
        {
            if (child == nullptr)
                continue;
            rt->sum += child->sum;
            rt->MAX = max(rt->MAX, child->MAX);
            rt->MIN = min(rt->MIN, child->MIN);
        }
    }
};
//-------------test---------------------
// Prints every element of a on one line, each followed by a space, then ends
// the line. Takes a const reference so const vectors and temporaries can be
// printed as well.
void printVec(const vector<int> &a)
{
    for (int ai : a)
        cout << ai << " ";
    cout << endl;
}
int main()
{
vector<int> v = {1,2,7,8,5};
printVec(v);
cout << "建立线段树" << endl;
SegmentTree sgtree(v);
printVec(sgtree.data);
cout << "查询区间的sum,MIN,MAX" << endl;
vector<int> qy_res = sgtree.query(sgtree.root,1,3);
printVec(qy_res);
cout << "修改某位置的值" << endl;
sgtree.modify(sgtree.root,1,100);
printVec(sgtree.data);
cout << "查询区间的sum,MIN,MAX" << endl;
qy_res = sgtree.query(sgtree.root,1,3);
printVec(qy_res);
return 0;
}
运行结果(使用 valgrind ./a.out 运行,同时检查内存泄漏):
==16895== Memcheck, a memory error detector
==16895== Copyright (C) 2002-2017, and GNU GPL'd, by Julian Seward et al.
==16895== Using Valgrind-3.14.0 and LibVEX; rerun with -h for copyright info
==16895== Command: ./a.out
==16895==
1 2 7 8 5
建立线段树
1 2 7 8 5
查询区间的sum,MIN,MAX
17 2 8
修改某位置的值
1 100 7 8 5
查询区间的sum,MIN,MAX
115 7 100
==16895==
==16895== HEAP SUMMARY:
==16895== in use at exit: 0 bytes in 0 blocks
==16895== total heap usage: 29 allocs, 29 frees, 616 bytes allocated
==16895==
==16895== All heap blocks were freed -- no leaks are possible
==16895==
==16895== For counts of detected and suppressed errors, rerun with: -v
==16895== ERROR SUMMARY: 0 errors from 0 contexts (suppressed: 0 from 0)