这是我做的第一道线段树的题目,以前只了解过模板——下面我就来介绍一下。
-------------------------------------------线段树的介绍-------------------------------------------------------
线段树是什么?就是一种二叉树结构,而且它的每一个结点代表一条线段的和。
以下是百度百科的详细解释:
一般线段树我都是用堆来储存的,即a[i]的左儿子是a[i*2],右儿子是a[i*2+1]。
线段树快在哪里呢?他对于每个结点的修改和访问都是log(n)的效率。
相信到这里都很好理解,下面介绍一种更加强大的思想:lazy思想。
如果我要将a--b的区间的每一个值都加上p,那么效率就是(b-a+1)log(n)。有没有更快的方法呢?
通过研究,我们发现:如果它只修改而并不访问,我们岂不是不用更新了!
于是我们可以再开一个域cnt,表示目前有多少次累加被积压着。每次处理到k结点时(注意,这个处理包括访问和修改等),我们都把它积压的值给加上,然后传给它的儿子。此时它的儿子不必处理,直到要处理到它的儿子为止。
下面是这强大的update代码:
void update(long k)
{
if (a[k].cnt==0) return;
a[k].sum+=(a[k].r-a[k].l+1)*a[k].cnt;
a[k*2].cnt+=a[k].cnt;a[k*2+1].cnt+=a[k].cnt;
a[k].cnt=0;
}
每次处理到k时,先做一遍update(k)再进行操作即可。
------------------------------------------------回到原题-------------------------------------------------------
这里可能相对简单一些,下面说一些与模板不太相同的细节。
(1)初始建树时,每个值都是0.
(2)在插入时,每次cnt累加1即可,这个1当然不是加上1,而是执行操作数。
(3)在update中,如果cnt是偶数,就可以直接清零(因为偶数次不是抵消了嘛!);如果是奇数,我们把它相反一下,并向下传1。
代码:
#include<stdio.h>
using namespace std;
const long maxn=100001;
struct tree
{
long l,r; //l和r就是a[i]所表示的左右区间范围。
long cnt,sum; //cnt是累加的操作次数,sum是a[i]表示的线段中开着的灯的数量。
}a[4*maxn];
long n,m,i,x,y,ok;
void build(long k,long l,long r) //基础建树
{
a[k].cnt=0;a[k].l=l;a[k].r=r;
if (l==r) return;
long mid=(l+r)/2;
build(k*2,l,mid);build(k*2+1,mid+1,r);
}
void update(long k) //update操作。
{
if (a[k].cnt==0) return;
if (a[k].cnt%2==1)
{
a[k].sum=a[k].r-a[k].l+1-a[k].sum; //a[k]的区间是a[k].l到a[k].r,那么共有a[k].r-a[k].l+1个,这里取反。
a[k*2].cnt+=1;a[k*2+1].cnt+=1;
a[k].cnt=0;
}
else a[k].cnt=0;
}
void ins(long k,long l,long r)
{
update(k);
if (a[k].l>=l&&a[k].r<=r) {a[k].cnt+=1;return;} //当前区域被覆盖
long mid=(a[k].l+a[k].r)/2; //二分找被覆盖的区域。
if (l<=mid) ins(k*2,l,r);
if (r>mid) ins(k*2+1,l,r);
update(k*2);update(k*2+1); //对儿子的更新操作(很重要),因为下面一步要重新更新当前的值。
a[k].sum=a[k*2].sum+a[k*2+1].sum;
}
long find(long k,long l,long r) //这与ins大同小异。
{
update(k);
if (a[k].l>=l&&a[k].r<=r) return a[k].sum;
long mid=(a[k].l+a[k].r)/2,o=0;
if (l<=mid) o+=find(k*2,l,r);
if (r>mid) o+=find(k*2+1,l,r);
update(k*2);update(k*2+1);
a[k].sum=a[k*2].sum+a[k*2+1].sum;
return o;
}
int main()
{
//freopen("lites.in","r",stdin);
//freopen("lites.out","w",stdout);
scanf("%ld%ld",&n,&m);
build(1,1,n);
for (i=1;i<=m;i++)
{
scanf("%ld%ld%ld",&ok,&x,&y);
if (ok==0) ins(1,x,y);
else {long ans=find(1,x,y);printf("%ld\n",ans);}
}
//scanf("%ld",&n,&m);
return 0;
}