线段树(Segment Tree)
目录
首先我们知道二叉树,balanced 二叉树可以保证查找的复杂度是logn的复杂度。理解起来十分直观:大概这样子:
之所以是平衡的,是因为它的构建方式,代码上能看出来。
如你所见,底层的就是单个的元素,每往上走一层,就会根据底层的元素做一个范围查询(rangeQuery)操作,这个操作可以是求和,求最小,求最大,同时上一层就会cover all the union of the children's range
there are four operations in total:
- build(start,end,vals)->O(n)
- update(index,value)->O(logn)
- rangeQuery(satrt,end)->O(logn+K) K is the number of the reported segments.
那我们应该如何实现呢?
其实整体的写法跟快排有异曲同工之妙,也是递归不断的寻找中间点,如果左右满足结束要求,就结束。
结构体
当然首先我们要定义它的每一个节点的结构,这里start,end是左右闭区间的。
struct SegmentTreeNode
{
int start;
int end;
int sumnum;
SegmentTreeNode* left;
SegmentTreeNode* right;
SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
:start(s), end(e), sumnum(n), left(l), right(r) {};
};
构建
接下来我们借助递归来构建这棵树,当然如果是项目落地还是用用智能指针吧。实际上复杂度是2n啦,但是忽略掉2那就是n。如果是奇数,那么我们的构建方式保证左边多一个leaf。
SegmentTreeNode* buildTree(int start, int end, vector<int>& vals)
{
if (start == end)
{
return new SegmentTreeNode(start, end, vals[start]);
}
int mid = start + (end - start) / 2;
SegmentTreeNode* left = buildTree(start, mid, vals);
SegmentTreeNode* right = buildTree(mid + 1, end, vals);
return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
}
单个更新
接下来是更新,更新的操作其实跟构建差不多,我们借助递归的方式,不断地更新当前的node,直到我们找到了start==end==index之后,,我们更新它的值并且返回,在返回的途中,我们不断更新它的父节点的值。忽略掉每一步的更新操作那就是logn。
void updateTree(SegmentTreeNode* root, int index, int val)
{
if (root->start == root->end && root->start == index)
{
root->sumnum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if (index <= mid)
{
updateTree(root->left, index, val);
}
else
{
updateTree(root->right, index, val);
}
root->sumnum = root->left->sumnum + root->right->sumnum;
}
范围查询
接下来看看范围查询,挺好理解的, 比较的时候记住什么时候加等号就行。比如val的length是5,那么一半就是2,为了平衡可能是01在左,234在右,根绝构建方式的不同就不同,这里我们的构建方式是012在做,34在右,因此当有边界小于等于2的时候就在左边,而不是小于。
int rangeQuery(SegmentTreeNode* root, int left, int right)
{
if (root->start == left && root->end == right)
{
return root->sumnum;
}
int mid = root->start + (root->end - root->start) / 2;
if (right <= mid) //完全落在左边
{
return rangeQuery(root->left, left, right);
}
else if (left > mid) //完全落在右边
{
return rangeQuery(root->right, left, right);
}
else //落在中间
{
return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
}
}
例题1
力扣https://leetcode-cn.com/problems/range-sum-query-mutable/
加个智能指针稍微改吧改吧:
class NumArray {
struct SegmentTreeNode
{
int start;
int end;
int sumnum;
unique_ptr<SegmentTreeNode> left;
unique_ptr<SegmentTreeNode> right;
SegmentTreeNode(int s, int e, int n,SegmentTreeNode* l=NULL,SegmentTreeNode* r=NULL)
:start(s), end(e), sumnum(n), left(l), right(r) {};
};
SegmentTreeNode* buildTree(int start, int end, vector<int>& vals)
{
if (start == end)
{
return new SegmentTreeNode(start, end, vals[start]);
}
int mid = start + (end - start) / 2;
SegmentTreeNode* left = buildTree(start, mid, vals);
SegmentTreeNode* right = buildTree(mid + 1, end, vals);
return new SegmentTreeNode(start, end, left->sumnum + right->sumnum, left, right);
}
void updateTree(unique_ptr<SegmentTreeNode>& root, int index, int val)
{
if (root->start == root->end && root->start == index)
{
root->sumnum = val;
return;
}
int mid = root->start + (root->end - root->start) / 2;
if (index <= mid)
{
updateTree(root->left, index, val);
}
else
{
updateTree(root->right, index, val);
}
root->sumnum = root->left->sumnum + root->right->sumnum;
}
int rangeQuery(unique_ptr<SegmentTreeNode>& root, int left, int right)
{
if (root->start == left && root->end == right)
{
return root->sumnum;
}
int mid = root->start + (root->end - root->start) / 2;
if (right <= mid) //完全落在左边
{
return rangeQuery(root->left, left, right);
}
else if (left > mid) //完全落在右边
{
return rangeQuery(root->right, left, right);
}
else //落在中间
{
return rangeQuery(root->left, left, mid) + rangeQuery(root->right, mid + 1, right);
}
}
unique_ptr<SegmentTreeNode> head;
public:
NumArray(vector<int>& nums) {
head=unique_ptr<SegmentTreeNode>(buildTree(0,nums.size()-1,nums));
}
void update(int index, int val) {
updateTree(head,index,val);
}
int sumRange(int left, int right) {
return rangeQuery(head,left,right);
}
};
/**
* Your NumArray object will be instantiated and called as such:
* NumArray* obj = new NumArray(nums);
* obj->update(index,val);
* int param_2 = obj->sumRange(left,right);
*/
寄啊,智能指针真慢,凑活看……
范围更新
如果数据量过大,比如1e9这种级别,我们不可能初始化这么大的一个二叉树,因此我们需要一种动态创建和删除结点的方式来创建这个线段树
struct Node
{
int left;
int right;
int sum;
int lazy;
Node* leftChild;
Node* rightChild;
Node(int l, int r, int s = 0,int la =0,Node* lc=NULL,Node* rc=NULL)
:left(l), right(r),sum(s),lazy(la),leftChild(lc),rightChild(rc)
{}
};
其实不一样的就是有个lazy值,这个值的意义在于,如果我们访问到这个结点的时候,我们需要知道这个结点的下层的某个结点的值,如果lazy值不为0,那么说明这些值需要被修正到lazy。
也就是说,只有我们在访问到被更改的结点的时候,我们才会修正这个结点的值,否则,那些不被我们关心的下层结点根本不会被动或者甚至不会被创建。
我们在两种情况下需要更新一个结点的lazy值:
- 更新区间的时候,如果区间被细分成了多个更细的区间,那么我们需要更新对应区间的值,
- 查询的时候,如果我们的兴趣区间的父节点的lazy值不为0,那么我们需要创建对应的值或者更新。
更新的过程就是创建并且将lazy值扔给两个子节点,并且修改其的sum值,这种更新方式就跟病毒传播一样,因此我们命名为infect。
void infect(Node* root)
{
int mid = (root->right - root->left) / 2 + root->left;
if (root->leftChild == NULL)
root->leftChild = new Node(root->left, mid);
if (root->rightChild == NULL)
root->rightChild = new Node(mid + 1, root->right);
if (root->lazy != 0)
{
root->leftChild->lazy = root->rightChild->lazy = root->lazy;
root->leftChild->sum = (mid - root->left + 1) * root->lazy;
root->rightChild->sum = (root->right - (mid + 1) + 1) * root->lazy;
root->lazy = 0;
}
}
有了感染函数,我们就能仿造之前的二叉树函数写出对应区间更新和查询函数:
void update(Node* root, int left, int right, int value)
{
if((root->left>right)||(root->right<left)) return;// out of range
if (root->left >= left && root->right <= right) //fully within the range
{
root->lazy = value;//change value
root->sum = (root->right - root->left + 1) * value;//change sum
}
else //partcially in range
{
infect(root);
update(root->leftChild, left, right, value);
update(root->rightChild, left, right, value);
root->sum = root->leftChild->sum + root->rightChild->sum;
}
}
int query(Node* root, int left, int right)
{
if (left <= root->left && root->right < right) return root->sum;
int mid = (root->right - root->left) / 2 + root->left;
infect(root);
if(right <= mid)
{
return query(root->leftChild, left, right);
}
if (left >= mid+1)
{
return query(root->rightChild, left, right);
}
return query(root->leftChild, left, mid) + query(root->rightChild, mid + 1, right);
}
例题2
力扣https://leetcode.cn/problems/range-module/ 有了上面的动态更新感染线段树,我们可以容易写出下面的代码:
struct Node
{
int left;
int right;
int sum;
int lazy;
//0:expired,1:modified,2:shared
Node* leftChild;
Node* rightChild;
Node(int l, int r, int s = 0,int la =0,Node* lc=NULL,Node* rc=NULL)
:left(l), right(r),sum(s),lazy(la),leftChild(lc),rightChild(rc)
{}
};
void infect(Node* root)
{
int mid = (root->right - root->left) / 2 + root->left;
if (root->leftChild == NULL)
root->leftChild = new Node(root->left, mid);
if (root->rightChild == NULL)
root->rightChild = new Node(mid + 1, root->right);
if (root->lazy != 2)
{
root->leftChild->lazy = root->rightChild->lazy = root->lazy;
root->leftChild->sum = (mid - root->left + 1) * root->lazy;
root->rightChild->sum = (root->right - (mid + 1) + 1) * root->lazy;
root->lazy = 2;
}
}
void update(Node* root, int left, int right, int value)
{
if((root->left>right)||(root->right<left)) return;// out of range
if (root->left >= left && root->right <= right) //fully within the range
{
root->lazy = value;//change value
root->sum = (root->right - root->left + 1) * value;//change sum
}
else //partcially in range
{
infect(root);
update(root->leftChild, left, right, value);
update(root->rightChild, left, right, value);
root->sum = root->leftChild->sum + root->rightChild->sum;
}
}
int query(Node* root, int left, int right)
{
if (left <= root->left && root->right <= right) return root->sum;
int mid = (root->right - root->left) / 2 + root->left;
infect(root);
if(right <= mid)
return query(root->leftChild, left, right);
if (left >= mid+1)
return query(root->rightChild, left, right);
return query(root->leftChild, left, mid) + query(root->rightChild, mid + 1, right);
}
class RangeModule {
Node* root;
public:
RangeModule() {
root = new Node(0, int(1e9));
}
void addRange(int left, int right) {
update(root, left, right-1, 1);
}
bool queryRange(int left, int right) {
return (right - left) == query(root, left, right - 1);
}
void removeRange(int left, int right) {
update(root, left, right - 1, 0);
}
};
麻中麻……