本题是线段树的模板题,对于初学者来说难度适中,我们需要对代码段中的一些细节进行一些推敲。
- 线段树是对数组的区间的操作,我们需要注意的是,在表示区间下标的时候,最好用1~n,不要用0~n-1,这样可以帮助我们更简洁的表示一段区间的起点、终点及其区间长度。这个细节会贯穿整段代码。而且本题的数据范围是long long,所以我们可以将代码中所有的变量都定义为long long型,方便阅读和理解。
- 在表示父节点与其子节点的时候,因为我们的数组下标设置为从1开始了,那么假设父节点下标为i,其子节点的下表就分别为2*i和2*i+1。在找他的左右子节点的下标的时候可以写两个函数,这样就不用一直写2*i和2*i+1了,而且函数名更容易理解。乘2、加1可以用位运算代替。这部分代码如下:
//左节点的下标
ll ls(ll x)
{
return x<<1;
}
//右节点的下标
ll rs(ll x)
{
return (x<<1)|1;//左移后其最后一位是0,|1可以代替+1。
}
- 原数组应该储存在一个单独的数组内,方便建立线段树。延迟标记tag在建立线段树的时候初始化。建立线段树的代码如下:
void build(ll root,ll l,ll r)
{
tag[root]=0;
if(l==r)
ans[root]=a[l];
else
{
ll mid=(l+r)>>1;
build(ls(root),l,mid);
build(rs(root),mid+1,r);
ans[root]=ans[ls(root)]+ans[rs(root)];
}
}
- 区间更新是个非常麻烦的操作,主要分为向下传递延迟标记和更新结点信息两部分。若当前遍历到的区间包含于要更新的区间内,那么就直接更新,不需要进行传递标记操作。若不是的话,需要先进行一次传递标记操作,再递归地对左右子区间进行更新,最后更新当前节点信息。注意一个节点可能同时被改变了多次,所以在更新其延迟标记时应该是累加起来。更新结点信息操作的代码如下:
//root表示当前遍历到的结点
//u_l、u_r表示要更新的区间的起点和终点
//nl、nr表示当前结点所表示的区间的起点和终点。
//val表示要加的值
void update(ll root,ll u_l,ll u_r,ll nl,ll nr,ll val)
{
if(u_l<=nl && u_r>=nr)
{
ans[root]+=val*(nr-nl+1);
tag[root]+=val;
return;
}
//向下传递延迟标记函数
pushdown(root,nl,nr);
ll mid=(nl+nr)>>1;
if(u_l<=mid)
update(ls(root),u_l,u_r,nl,mid,val);
if(u_r>mid)
update(rs(root),u_l,u_r,mid+1,nr,val);
ans[root]=ans[ls(root)]+ans[rs(root)];
}
- 向下传递延迟标记主要是为了根据父节点的延迟标记信息来更新子节点的信息,这里注意更新子节点的值的时候,我们要计算子节点的区间长度,不要误用了父节点的区间长度,最后不要忘了清空父节点的延迟标记。代码如下:
//root表示父节点的下标
//l、r表示父节点所表示的区间的起点和终点
void pushdown(ll root,ll l,ll r)
{
ll mid=(l+r)>>1;
//f函数为更新子节点信息的函数
f(ls(root),l,mid,tag[root]);
f(rs(root),mid+1,r,tag[root]);
tag[root]=0;
}
void f(ll root,ll l,ll r,ll val)
{
tag[root]+=val;
ans[root]+=val*(r-l+1);
}
- 最后是查询操作,若当前遍历到的区间被包含在了要查询的区间内,则直接返回当前节点的信息。否则,我们就要向下递归查找子节点的信息,所以此时我们应该进行一次向下传递延迟标记操作。代码如下:
//root表示当前遍历到的结点的下标
//q_l、q_r表示要查询的区间的起点和终点
//n_l、n_r表示当前遍历到的结点所代表的区间
ll query(ll root,ll q_l,ll q_r,ll n_l,ll n_r)
{
if(q_l<=n_l && q_r>=n_r)
return ans[root];
pushdown(root,n_l,n_r);
ll res=0;
ll mid=(n_l+n_r)>>1;
if(q_l<=mid)
res+=query(ls(root),q_l,q_r,n_l,mid);
if(q_r>mid)
res+=query(rs(root),q_l,q_r,mid+1,n_r);
return res;
}
- 最后附上完整源代码
#include<iostream>
#include<cstdio>
#define MAXN 1000001
#define ll long long
using namespace std;
ll n,m,t,x,y,k,a[MAXN],ans[MAXN<<2],tag[MAXN<<2];
ll ls(ll x)
{
return x<<1;
}
ll rs(ll x)
{
return (x<<1)|1;
}
void build(ll root,ll l,ll r)
{
tag[root]=0;
if(l==r)
ans[root]=a[l];
else
{
ll mid=(l+r)>>1;
build(ls(root),l,mid);
build(rs(root),mid+1,r);
ans[root]=ans[ls(root)]+ans[rs(root)];
}
}
void f(ll root,ll l,ll r,ll val)
{
tag[root]+=val;
ans[root]+=val*(r-l+1);
}
void pushdown(ll root,ll l,ll r)
{
ll mid=(l+r)>>1;
f(ls(root),l,mid,tag[root]);
f(rs(root),mid+1,r,tag[root]);
tag[root]=0;
}
void update(ll root,ll u_l,ll u_r,ll nl,ll nr,ll val)
{
if(u_l<=nl && u_r>=nr)
{
ans[root]+=val*(nr-nl+1);
tag[root]+=val;
return;
}
pushdown(root,nl,nr);
ll mid=(nl+nr)>>1;
if(u_l<=mid)
update(ls(root),u_l,u_r,nl,mid,val);
if(u_r>mid)
update(rs(root),u_l,u_r,mid+1,nr,val);
ans[root]=ans[ls(root)]+ans[rs(root)];
}
ll query(ll root,ll q_l,ll q_r,ll n_l,ll n_r)
{
if(q_l<=n_l && q_r>=n_r)
return ans[root];
pushdown(root,n_l,n_r);
ll res=0;
ll mid=(n_l+n_r)>>1;
if(q_l<=mid)
res+=query(ls(root),q_l,q_r,n_l,mid);
if(q_r>mid)
res+=query(rs(root),q_l,q_r,mid+1,n_r);
return res;
}
int main()
{
cin>>n>>m;
for(ll i=1;i<=n;i++)
cin>>a[i];
build(1,1,n);
while(m--)
{
cin>>t;
if(t==1)
{
scanf("%lld%lld%lld",&x,&y,&k);
update(1,x,y,1,n,k);
}
else
{
scanf("%lld%lld",&x,&y);
printf("%lld\n",query(1,x,y,1,n));
}
}
return 0;
}