线段树用于解决区间的问题,算是对分块算法的优化,但功能更为强大。
我们定义一颗树,树上存储的结构为区间、统计值等。
线段树的存储结构:
1.链表:
struct Node{
int left,right,sum;
Node *lchild,*rchild;
};
2.数组模拟链表法
定义4个int数组,left[],right[],lchid[],rchild[]来分别存储左端点,右端点,左孩子索引,右孩子索引。
3.数组形成的堆结构
对于一个索引i对应的区间为[l,r],如果区间不是叶子区间,则分出左孩子和右孩子,mid=(l+r)/2,左孩子索引为2*i,对应区间为[l,mid],右孩子索引为2*i+1,对应区间为[mid+1,r]。
4.更为紧凑的结构
这里要分析一下堆结构的存储损耗。我们直接考虑满二叉树的情形,节点个数为。
假设初始的区间为[1,n]。,加1是因为本身初始区间占据一层,于是保险起见,节点个数算得的结果为4n-1。
然而4n-1的区间长度是假象,只是因为在最后一层多了仅仅几个节点而导致索引范围的扩大,因此实际上真正存有区间的还是2n-1。
一种改进的方式是对于区间为[l,r)的节点,存储在(l+r-1) bitwise_or diff(1,r-1)的索引种,diff为判断传入的两值是否相等的函数,相等返回0,不等返回1。bitwise_or为按位或运算。
注意:前三种存储都可以存区间[l,r],但是第四种要求只能是[l,r)。本编文章以链式存储为主,并且博主本人习惯直接表示闭区间[l,r]。
递归建立线段树:
void build(Node *root,int l,int r){
root->left=l;
root->right=r;
int mid=(l+r)/2;
if(l!=r){//非叶子节点
root->lchild=new Node;
root->rchild=new Node;
build(root->lchild,l,mid);
build(root->rchild,mid+1,r);
root->sum=root->lchild->sum+root->rchild->sum;
}
root->lchild=root->rchild=null;
root->sum=Arr[l];
}
线段树的查询(以区间和为例):
int query(Node *root,int l,int r){
if(l<=root->left&&r>=root->right){//完全被包含的区间
return root->sum;
}
int ans=0,mid=(root->left+root->right)/2;
if(l<=mid){//如果有未计算完的部分在左半区间 递归
ans+=query(root->lchild,l,r)
}
if(r>mid){
ans+=query(root->rchild,l,r);
}
return ans;
}
解释: if(l<=mid){
ans+=query(root->lchild,l,r);
}
此处不调整递归函数的[l,r]区间的原因是,我们还不知道是剩下的区间全部落在左半边还是一部分落在左半边。虽然可以加一个判断,但这样就会变成4个if,增大了代码冗余度。而这种写法,并不会影响正误,我们返回区间的结果仅仅看这个区间是否落在所要求的区间内,也不会发生重叠,因为我们只返回了父亲,而交叠只会发生在父子之间。
线段树的单点修改:
void change(Node *root,int locate,int delta){//单点更新
if(root->left==root->right){//要修改的叶子
root->sum+=delta;
}
int mid=(root->left+root->right)/2;
if(locate<=mid){//目标在左半区间
change(root->lchild,locate,delta);
}
else{//目标在右半区间
change(root->rchild,locate,delta);
}
root->sum=root->lchild->sum+root->rchild->sum;
}
lazy操作:
以上图为例,如果我们对一个区间依旧采用单点更新的话,那么一次点更新的复杂度为O(logn),m个点的更新复杂度则变为O(mlogn)。但实际上,假设我们想要更新[2,6]区间的话,在这棵线段树里的递归过程首先找到的父节点为[2,2],[3,4],[5,6],此次是没有必要更新到最底的,先更新到父亲上,如果后期需要查询或者更新上述父节点的子区间时再往下更新即可,这就是延迟标记,也称懒标记(lazy)。
延迟标记下传函数:
void push_down(Node *root){
root->lchild->sum+=root->delta*(root->lchild->right-root->lchild->left+1);
root->rchild->sum+=root->delta*(root->rchild->right-root->rchild->left+1);
root->lchild->delta=root->rchild->delta=root->delta;
root->delta=0;
}
带有lazy标记的区间更新:
void modify(Node *root,int l,int r,int delta){
if(l<=root->left&&r>=root->right){//属于要更新的区间
root->sum+=(root->right-root->left+1)*delta;//区间更新
root->delta+=delta;//记上标记
return;
}
//当前区间过大 要往下走
if(root->delta!=0){//当前节点有未下传的标记
push_down(root);
}
int mid=(root->left+root->right)/2;
if(l<=mid){
modify(root->lchild,l,r,delta);
}
if(r>mid){
modify(root->rchild,l,r,delta);
}
root->sum=root->left->sum+root->right->sum;
}
带有lazy标记的区间查询:
int query(Node *root,int l,int r){
if(l<=root->left&&r>=root->right){//完全被包含的区间
return root->sum;
}
if(root->delta!=0)push_down(root);
int ans=0,mid=(root->left+root->right)/2;
if(l<=mid){//如果有未计算完的部分在左半区间 递归
ans+=query(root->lchild,l,r)
}
if(r>mid){
ans+=query(root->rchild,l,r);
}
return ans;
}
poj3468:数据量大,建点会T
#include<iostream>
using namespace std;
typedef long long ll;
int N,Q;
ll Arr[100005];
struct Node{
int left,right;
ll sum,delta;
Node *lchild,*rchild;
};
void build(Node* root,int l,int r){
root->left=l;
root->right=r;
root->delta=0;
int mid=(root->left+root->right)/2;
if(l!=r){//非叶子
root->lchild=new Node;
root->rchild=new Node;
build(root->lchild,l,mid);
build(root->rchild,mid+1,r);
root->sum=root->lchild->sum+root->rchild->sum;
return;
}
root->lchild=root->rchild=NULL;
root->sum=Arr[l];
}
void push_down(Node *root){
root->lchild->sum+=root->delta*(root->lchild->right-root->lchild->left+1);
root->rchild->sum+=root->delta*(root->rchild->right-root->rchild->left+1);
root->lchild->delta+=root->delta;
root->rchild->delta+=root->delta;
root->delta=0;
}
ll query(Node *root,int l,int r){
if(l<=root->left&&r>=root->right){
return root->sum;
}
else{
if(root->delta!=0)push_down(root);
ll ans=0;
int mid=(root->left+root->right)/2;
if(l<=mid){
ans+=query(root->lchild,l,r);
}
if(r>mid){
ans+=query(root->rchild,l,r);
}
return ans;
}
}
void modify(Node *root,int l,int r,ll delta){
if(l<=root->left&&r>=root->right){
root->sum+=delta*(root->right-root->left+1);
root->delta+=delta;
}
else{
if(root->delta!=0)push_down(root);
int mid=(root->left+root->right)/2;
if(l<=mid){
modify(root->lchild,l,r,delta);
}
if(r>mid){
modify(root->rchild,l,r,delta);
}
}
}
int main(){
cin>>N>>Q;
int i,l,r;
ll delta;
char op;
Node *root=new Node;
for(i=1;i<=N;i++)cin>>Arr[i];
build(root,1,N);
for(i=1;i<=Q;i++){
cin>>op;
switch(op){
case 'Q':cin>>l>>r;
cout<<query(root,l,r)<<endl;
break;
case 'C':cin>>l>>r>>delta;
modify(root,l,r,delta);
break;
}
}
return 0;
}
经试验,,TLE很可能是读的问题,关闭流同步即可AC,以下是使用堆的线段树版本。
#include<iostream>
using namespace std;
typedef long long ll;
int N,Q,i,L,R;
ll Arr[100005],delta;
char op;
struct Node{
int l,r;
ll sum,delta;
};
Node Tree[400020];
void build(int root,int l,int r){
Tree[root].l=l;
Tree[root].r=r;
Tree[root].delta=0;
if(l!=r){
int mid=(l+r)/2,son1=root*2;
int son2=son1+1;
build(son1,l,mid);
build(son2,mid+1,r);
Tree[root].sum=Tree[son1].sum+Tree[son2].sum;
return;
}
Tree[root].sum=Arr[l];
}
void push_down(int root){//不需要考虑叶子节点下传越界的情况 因为如果查到了叶子节点 他必定是属于结果的一部分 不会经过下传过程
int son=root*2;
Tree[son].sum+=(Tree[root].delta)*(Tree[son].r-Tree[son].l+1);
Tree[son].delta+=Tree[root].delta;
son++;
Tree[son].sum+=(Tree[root].delta)*(Tree[son].r-Tree[son].l+1);
Tree[son].delta+=Tree[root].delta;
Tree[root].delta=0;
}
ll query(int root,int l,int r){
if(l<=Tree[root].l&&r>=Tree[root].r){
return Tree[root].sum;
}
else{
if(Tree[root].delta!=0)push_down(root);
ll ans=0;
int mid=(Tree[root].l+Tree[root].r)/2;
if(l<=mid){
ans+=query(root*2,l,r);
}
if(r>mid){
ans+=query(root*2+1,l,r);
}
return ans;
}
}
void modify(int root,int l,int r){
if(l<=Tree[root].l&&r>=Tree[root].r){
Tree[root].sum+=delta*(Tree[root].r-Tree[root].l+1);
Tree[root].delta+=delta;
return;
}
else{
if(Tree[root].delta!=0)push_down(root);
int mid=(Tree[root].l+Tree[root].r)/2;
if(l<=mid){
modify(root*2,l,r);
}
if(r>mid){
modify(root*2+1,l,r);
}
Tree[root].sum=Tree[root*2].sum+Tree[root*2+1].sum;
return;
}
}
int main(){
std::ios::sync_with_stdio(false);
cin>>N>>Q;
for(i=1;i<=N;i++)cin>>Arr[i];
build(1,1,N);
for(i=1;i<=Q;i++){
cin>>op;
switch(op){
case 'Q':cin>>L>>R;
cout<<query(1,L,R)<<endl;
break;
case 'C':cin>>L>>R>>delta;
modify(1,L,R);
break;
}
}
}