代码模板:
https://github.com/rsy56640/rsy_little_lib/tree/master/library_for_algorithm/SegmentTree
线段树原理:
对于一个给定长度的区间 $l$,我们递归地定义:
区间 $l$ 的线段树为:其根节点表示区间 $l$;将该区间近似等分为两段 $l_1$、$l_2$,左子树为区间 $l_1$ 的线段树,右子树为区间 $l_2$ 的线段树。
并且提供一个二元代数运算 $\circ$,使值类型关于 $\circ$ 构成一个幺半群(即 $\circ$ 满足结合律,且存在幺元)。
线段树只提供两个方法:
1. 区间查询
2. 单点修改
以上两种操作的时间复杂度均为 $O(\log n)$。
SegmentTreeType.h
#pragma once
#ifndef SEGMENTTREETYPE_H
#define SEGMENTTREETYPE_H
#include <functional>
#include <memory>
#include <type_traits>

template<class _Ty>
class SegmentTreeNode;

// Type aliases shared by SegmentTreeNode, SegmentTreeImpl and SegmentTree.
// _Ty is the element type; together with Func and an identity element it is
// expected to form a monoid (Func associative, identity neutral).
template<class _Ty>
struct SegmentTreeType
{
    // _Ty with any reference qualifier removed.
    using value_type = typename std::remove_reference<_Ty>::type;
    // Owning handle to a tree node. std::shared_ptr replaces the non-standard
    // std::tr1::shared_ptr, which conforming <memory> headers do not provide.
    using SegmentTreeNode_ptr = std::shared_ptr<SegmentTreeNode<_Ty> >;
    // Binary operation that combines two child values into the parent value.
    using Func = std::function<_Ty(const _Ty&, const _Ty&)>;
};
#endif // !SEGMENTTREETYPE_H
SegmentTreeNode.h
#pragma once
#ifndef SEGMENTTREENODE_H
#define SEGMENTTREENODE_H
#include "SegmentTreeType.h"
#include <utility>

// One node of the segment tree. It covers the closed index range
// [start(), end()] and stores the combined value of that range.
// Nodes built as leaves have start() == end() and null children.
template<class _Ty> class SegmentTreeNode :public SegmentTreeType<_Ty>
{
public:
    // Names declared in a dependent base class are not found by unqualified
    // lookup inside a template (two-phase lookup), so re-export the alias.
    using SegmentTreeNode_ptr = typename SegmentTreeType<_Ty>::SegmentTreeNode_ptr;

    // Creates a node covering [start, end] that holds val; children start null.
    SegmentTreeNode(int start, int end, _Ty val)
        :_start(start), _end(end), _val(std::move(val)), _left(nullptr), _right(nullptr) {}

    // First index covered by this node (inclusive).
    int start() const noexcept
    {
        return _start;
    }

    // Last index covered by this node (inclusive).
    int end() const noexcept
    {
        return _end;
    }

    // Returns a copy of the stored value. Not noexcept: copying _Ty may throw,
    // and a throw out of a noexcept function would call std::terminate.
    _Ty val() const
    {
        return _val;
    }

    // Read-only access to the stored value without copying.
    const _Ty& value() const noexcept
    {
        return _val;
    }

    // Replaces the stored value. Taking the argument by value accepts both
    // lvalues and rvalues (a bare _Ty&& parameter rejected lvalues); rvalue
    // callers still pay only a move.
    void setValue(_Ty value)
    {
        _val = std::move(value);
    }

    // Mutable handle to the left child (null for leaves); returned by
    // reference so the builder can attach a subtree.
    SegmentTreeNode_ptr& left() noexcept
    {
        return _left;
    }

    // Mutable handle to the right child (null for leaves).
    SegmentTreeNode_ptr& right() noexcept
    {
        return _right;
    }

    SegmentTreeNode(const SegmentTreeNode&) = delete;
    SegmentTreeNode& operator=(const SegmentTreeNode&) = delete;
    SegmentTreeNode(SegmentTreeNode&&) = delete;
    SegmentTreeNode& operator=(SegmentTreeNode&&) = delete;
    ~SegmentTreeNode() = default;

private:
    int _start, _end;
    _Ty _val;
    SegmentTreeNode_ptr _left, _right;
};
#endif // !SEGMENTTREENODE_H
SegmentTreeException.h
#pragma once
#ifndef SEGMENTTREEEXCEPTION_H
#define SEGMENTTREEEXCEPTION_H
#include <exception>
#include <ostream>
#include <string>
#include <utility>

// Exception thrown by the segment tree on invalid use (empty input, inverted
// query range, out-of-range index). The template parameter matches the tree's
// element type, so handlers catch SegmentTreeException<T> for a tree of T.
template<class _Ty> class SegmentTreeException :public std::exception
{
public:
    // msg: human-readable description of the failure.
    SegmentTreeException(std::string msg)
        :_msg(std::move(msg)) {}

    // Returns the failure description. A fixed generic string used to be
    // returned here, which hid the real message from handlers that catch
    // std::exception.
    const char* what() const noexcept override
    {
        return _msg.c_str();
    }

private:
    std::string _msg;

    // Writes the failure description to os.
    friend
    std::ostream& operator<<(std::ostream& os, const SegmentTreeException<_Ty>& e)
    {
        os << e._msg;
        return os;
    }
};
#endif // !SEGMENTTREEEXCEPTION_H
SegmentTreeImpl.h
#pragma once
#ifndef SEGMENTTREEIMPL_H
#define SEGMENTTREEIMPL_H
#include "SegmentTreeNode.h"
#include "SegmentTreeException.h"
#include <memory>
#include <utility>
#include <vector>
// Kept for backward compatibility: other headers name `vector` unqualified.
using std::vector;

// Pointer-based segment tree over a monoid (_Ty, _Func, _Identity_Element).
// Supports inclusive range queries and point updates.
template<class _Ty> class SegmentTreeImpl :public SegmentTreeType<_Ty>
{
public:
    // Aliases from the dependent base are invisible to unqualified lookup
    // inside a template, so re-export them.
    using SegmentTreeNode_ptr = typename SegmentTreeType<_Ty>::SegmentTreeNode_ptr;
    using Func = typename SegmentTreeType<_Ty>::Func;

    // Builds the tree over Vec in O(n) time and space.
    // func:             associative binary operation on _Ty.
    // Identity_Element: identity of func (func(e, x) == func(x, e) == x).
    // Throws SegmentTreeException<_Ty> if Vec is empty.
    SegmentTreeImpl(const vector<_Ty>& Vec, Func func, _Ty Identity_Element)
        : _root(nullptr), _Func(std::move(func)), _Identity_Element(std::move(Identity_Element)), _checked(false)
    {
        if (Vec.empty())
            throw SegmentTreeException<_Ty>("The Segment is empty!!");
        _root = build(0, static_cast<int>(Vec.size()) - 1, Vec);
        _checked = true;
    }

    // Combines the elements at indices [start, end] (inclusive) with _Func.
    // Indices outside the tree's range contribute _Identity_Element.
    // Throws SegmentTreeException<_Ty> if start > end. O(log n).
    _Ty query(int start, int end) const
    {
        if (!_checked)
            throw SegmentTreeException<_Ty>("The Segment is empty!!");
        if (start > end)
            throw SegmentTreeException<_Ty>("The querying range is invalid!!");
        return doQuery(_root, start, end);
    }

    // Sets the element at index to value and recomputes its ancestors. O(log n).
    // Throws SegmentTreeException<_Ty> if index is outside the tree's range.
    void modify(int index, _Ty&& value)
    {
        if (!_checked)
            throw SegmentTreeException<_Ty>("The Segment is empty!!");
        if (index < _root->start() || index > _root->end())
            throw SegmentTreeException<_Ty>("The Index is invalid!!");
        doModify(_root, index, std::move(value));
    }

    // Lvalue overload: copies value and delegates to the rvalue overload.
    void modify(int index, const _Ty& value)
    {
        modify(index, _Ty(value));
    }

    SegmentTreeImpl(const SegmentTreeImpl&) = delete;
    SegmentTreeImpl& operator=(const SegmentTreeImpl&) = delete;
    SegmentTreeImpl(SegmentTreeImpl&&) = delete;
    SegmentTreeImpl& operator=(SegmentTreeImpl&&) = delete;
    ~SegmentTreeImpl() = default;

protected:
    SegmentTreeNode_ptr _root;
    // Associative binary operation on _Ty with identity _Identity_Element;
    // (_Ty, _Func) is expected to form a monoid.
    Func _Func;
    // Identity element of _Func.
    const _Ty _Identity_Element;
    // True once the tree has been built successfully.
    bool _checked;

private:
    // Recursively builds the subtree covering [start, end] and returns its root.
    // Internal nodes store _Func(left value, right value).
    SegmentTreeNode_ptr build(int start, int end, const vector<_Ty>& Vec)
    {
        if (start == end)
            return std::make_shared<SegmentTreeNode<_Ty> >(start, end, Vec[start]);
        // Overflow-safe midpoint; equals (start + end) / 2 for non-negative bounds.
        const int mid = start + (end - start) / 2;
        SegmentTreeNode_ptr node =
            std::make_shared<SegmentTreeNode<_Ty> >(start, end, _Identity_Element);
        node->left() = build(start, mid, Vec);
        node->right() = build(mid + 1, end, Vec);
        node->setValue(_Func(node->left()->value(), node->right()->value()));
        return node;
    }

    // Combines the values of the parts of [start, end] that lie inside root's range.
    // The node handle is taken by const reference to avoid atomic refcount
    // traffic on every recursive call.
    _Ty doQuery(const SegmentTreeNode_ptr& root, int start, int end) const
    {
        // Disjoint ranges: contribute the identity.
        if (start > root->end() || end < root->start())
            return _Identity_Element;
        // root's range lies entirely inside the query range.
        if (start <= root->start() && root->end() <= end)
            return root->val();
        // Partial overlap: combine both halves.
        return _Func(doQuery(root->left(), start, end), doQuery(root->right(), start, end));
    }

    // Descends to the leaf for index, stores value there and recomputes every
    // ancestor on the way back up. index must already be within root's range.
    void doModify(const SegmentTreeNode_ptr& root, int index, _Ty&& value)
    {
        if (root->start() == root->end())
        {
            root->setValue(std::move(value));
            return;
        }
        // Must match the split used by build().
        const int mid = root->start() + (root->end() - root->start()) / 2;
        doModify(index <= mid ? root->left() : root->right(), index, std::move(value));
        root->setValue(_Func(root->left()->value(), root->right()->value()));
    }
};
#endif // !SEGMENTTREEIMPL_H
SegmentTree.h
#pragma once
#ifndef SEGMENTTREE_H
#define SEGMENTTREE_H
#include "SegmentTreeImpl.h"
#include <memory>
#include <utility>
#include <vector>

// Public facade over SegmentTreeImpl (pimpl idiom).
// NOTE: the implementation is held through std::shared_ptr, so copies of a
// SegmentTree share one underlying tree; modify() through one copy is
// visible through every other copy.
template<class _Ty> class SegmentTree :public SegmentTreeType<_Ty>
{
    using PImpl = std::shared_ptr<SegmentTreeImpl<_Ty> >;
public:
    // Re-exported from the dependent base so it is usable unqualified here.
    using Func = typename SegmentTreeType<_Ty>::Func;

    // Builds the tree over Vec with operation func and its identity element.
    // Throws SegmentTreeException<_Ty> if Vec is empty.
    SegmentTree(const std::vector<_Ty>& Vec, Func func, _Ty Identity_Element)
        :_pImpl(std::make_shared<SegmentTreeImpl<_Ty> >(Vec, std::move(func), std::move(Identity_Element))) {}

    // Combines the elements at [start, end] (inclusive). O(log n).
    // Throws SegmentTreeException<_Ty> if start > end.
    _Ty query(int start, int end) const
    {
        return _pImpl->query(start, end);
    }

    // Sets the element at index to value. O(log n).
    // Throws SegmentTreeException<_Ty> if index is out of range.
    void modify(int index, _Ty&& value)
    {
        _pImpl->modify(index, std::move(value));
    }

    // Lvalue overload: copies value and forwards it as an rvalue.
    void modify(int index, const _Ty& value)
    {
        _pImpl->modify(index, _Ty(value));
    }

private:
    PImpl _pImpl;
};
#endif // !SEGMENTTREE_H
main.cpp
#include "SegmentTree.h"
#include <cstdlib>
#include <iostream>
#include <limits>
#include <vector>

// Demo binary operation: the maximum of two ints.
int foo(int a, int b)
{
    return (a > b) ? a : b;
}

int main()
{
    std::vector<int> v = { 1,2,7,8,5 };
    try
    {
        // The identity of max over int is the smallest int, not 0: with 0,
        // a query over negative values would wrongly return 0.
        SegmentTree<int> h(v, foo, std::numeric_limits<int>::min());
        int a = h.query(0, 2);   // max{1,2,7}  = 7
        h.modify(0, 4);          // v = {4,2,7,8,5}
        int b = h.query(0, 1);   // max{4,2}    = 4
        h.modify(2, 11);         // v = {4,2,11,8,5}
        int c = h.query(2, 3);   // max{11,8}   = 11
        std::cout << a << ' ' << b << ' ' << c << '\n';
    }
    catch (const SegmentTreeException<int>& e)
    {
        std::cerr << e << '\n';
    }
#ifdef _WIN32
    // "pause" is a Windows-only shell command; it keeps the console open when
    // the program is launched outside a terminal.
    std::system("pause");
#endif
    return 0;
}