1.线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。这样的数据结构有助于我们对区间的处理。
2.我们举求区间和(区间最值其实差不多),并需要对区间元素进行修改为例,如下题
题目描述
如题,已知一个数列,你需要进行下面两种操作:
1.将某区间每一个数加上x
2.求出某区间每一个数的和
输入输出格式
输入格式:
第一行包含两个整数N、M,分别表示该数列数字的个数和操作的总个数。
第二行包含N个用空格分隔的整数,其中第i个数字表示数列第i项的初始值。
接下来M行每行包含3或4个整数,表示一个操作,具体如下:
操作1: 格式:1 x y k 含义:将区间[x,y]内每个数加上k
操作2: 格式:2 x y 含义:输出区间[x,y]内每个数的和
输出格式:
输出包含若干行整数,即为所有操作2的结果。
输入输出样例
输入样例#1:
5 5
1 5 4 2 3
2 2 4
1 2 3 2
2 3 4
1 1 5 1
2 1 4
输出样例#1:
11
8
20
说明
时空限制:1000ms,128M
数据规模:
对于30%的数据:N<=8,M<=10
对于70%的数据:N<=1000,M<=10000
对于100%的数据:N<=100000,M<=100000
(数据已经过加强_,保证在int64/long long数据范围内)
样例说明:
(1)如题,我们需要的进行的操作只有1.求某段区间和 2.将某段区间的元素的值增加
(2)很明显,要经过这样的操作,我们首先考虑就是通过前缀和相减,循环添加值,但是题目中的操作极多,如果直接模拟势必超时,于是我们就想到了专门用来处理区间问题的线段树
线段树
1、每个线段树的节点,都必须包含三个元素,分别用来存储区间头和区间尾以及对应的值,当然通常情况下我们都会加上一个标记数用来存储修改操作,因为如果我们遇到一次修改操作就将一颗树全部模拟修改,那么我们的线段树优化就等于没有了。通常情况使用一个结构体封装,好遍历,如下
struct node{
long long int l,r,f,v;
}tree[maxn];//这里的maxn通常要设为n*4,因为如果是一个满二叉树就必须要4*n,以防万一嘛
2.线段树的初始化及输入
因为我们这里使用线段树是为求区间和,那么我们就规定
每一个父节点就等于两个子节点之和
使用递归的方式输入,如下
void creat(long long int x,long long int y,long long int k)
{
tree[k].l=x;
tree[k].r=y;
if(x==y)//如果到底了就输入,然后层层递归回溯,给父节点赋值
{
scanf("%d",&tree[k].v);
return ;
}
long long int mid=(x+y)/2;//如果没有到底就继续向下
creat(x,mid,k*2);
creat(mid+1,y,k*2+1);
tree[k].v=tree[k*2].v+tree[k*2+1].v;
}
特殊说明
每个结构体中我们需要声明一个懒标记,用来存储对该区间所有的值的改变(一个标记只能代表一种算术情况,这里给出加法),每当扫到这个节点且该标记不为0,那就去依照这个标记改变一下其子节点的值,如下
void passdown(int k)
{
tree[k*2].f+=tree[k].f;
tree[k*2+1].f+=tree[k].f;
tree[k*2].v+=tree[k].f*(tree[k*2].r-tree[k*2].l+1);
tree[k*2+1].v+=tree[k].f*(tree[k*2+1].r-tree[k*2+1].l+1);
tree[k].f=0;//此标记已经被使用了,就变成0
}
3.实现求区间和操作
如何实现求区间[a,b]和?
我们只需要层层递归下去
(1)当遍历到的区间被包含在目标区间内,我们直接将sum(声明为全局变量,每次操作初始化为0)加上该被包含区间的和值就好了
(2)当不包含时,我们需要从线段树的性质入手,直接取扫到的区间的中点进行判定,如下
void check(long long int x,long long int y,long long int k)
{
if(tree[k].l>=x&&tree[k].r<=y)
{
sum+=tree[k].v;
return ;
}
if(tree[k].f) passdown(k);
long long int mid=(tree[k].l+tree[k].r)/2;
if(x<=mid) check(x,y,k*2);//扫左边的子节点
if(y>mid) check(x,y,k*2+1);//扫右边的子节点
}
4.对去区间值做加法
直接去扫,扫到被包含区间就加,没扫到继续二分下去查找
如下
void add(int x,int y,int k,int w)
{
if(tree[k].l>=x&&tree[k].r<=y)
{
tree[k].v+=w*(tree[k].r-tree[k].l+1);
tree[k].f+=w;
return ;
}
if(tree[k].f) passdown(k);//同理进行向下调整
int mid=(tree[k].l+tree[k].r)/2;
if(x<=mid) add(x,y,k*2,w);
if(y>mid) add(x,y,k*2+1,w);
tree[k].v=tree[k*2].v+tree[k*2+1].v;
}
5.附上AC代码
#include<cstring>
#include<cstdio>
#include<iostream>
#include<cstdlib>
#include<algorithm>
#include<cmath>
#define maxn 5000005
using namespace std;
struct node{
long long int l,r,f,v;
}tree[maxn];
long long int n,m,x;
void creat(long long int x,long long int y,long long int k)
{
tree[k].l=x;
tree[k].r=y;
if(x==y)
{
scanf("%d",&tree[k].v);
return ;
}
long long int mid=(x+y)/2;
creat(x,mid,k*2);
creat(mid+1,y,k*2+1);
tree[k].v=tree[k*2].v+tree[k*2+1].v;
}
long long int sum=0;
void passdown(long long int k)
{
tree[k*2].f+=tree[k].f;
tree[k*2+1].f+=tree[k].f;
tree[k*2].v+=tree[k].f*(tree[k*2].r-tree[k*2].l+1);
tree[k*2+1].v+=tree[k].f*(tree[k*2+1].r-tree[k*2+1].l+1);
tree[k].f=0;
}
void add(long long int x,long long int y,long long int k,long long int w)
{
if(tree[k].l>=x&&tree[k].r<=y)
{
tree[k].v+=w*(tree[k].r-tree[k].l+1);
tree[k].f+=w;
return ;
}
if(tree[k].f) passdown(k);
long long int mid=(tree[k].l+tree[k].r)/2;
if(x<=mid) add(x,y,k*2,w);
if(y>mid) add(x,y,k*2+1,w);
tree[k].v=tree[k*2].v+tree[k*2+1].v;
return ;
}
void check(long long int x,long long int y,long long int k)
{
if(tree[k].l>=x&&tree[k].r<=y)
{
sum+=tree[k].v;
return ;
}
if(tree[k].f) passdown(k);
long long int mid=(tree[k].l+tree[k].r)/2;
if(x<=mid) check(x,y,k*2);
if(y>mid) check(x,y,k*2+1);
}
long long int a,b,c;
int flag;
long long int xi;
int main()
{
cin>>n>>m;
creat(1,n,1);
// for(int i=1;i<=10;i++)
// cout<<tree[i].v<<" "<<endl;
for(int i=1;i<=m;i++)
{
sum=0;
cin>>flag;
if(flag==2)
{
scanf("%lld %lld",&a,&b);
check(a,b,1);
cout<<sum<<endl;
}
else
{
cin>>a>>b>>c;
add(a,b,1,c);
}
}
return 0;
}
下面这个代码是POJ3468的
#include<cstdio>
#include<iostream>
#include<algorithm>
#include<cstring>
#define maxn 4000000
using namespace std;
int n,m;
struct que{
long long int l,r,f,v;
}tr[maxn];
void creat( int x, int y,int k)
{
tr[k].l=x;
tr[k].r=y;
if(x==y)
{
scanf("%lld",&tr[k].v);
return;
}
int mid=(tr[k].l+tr[k].r)/2;
creat(x,mid,k*2);
creat(mid+1,y,k*2+1);
tr[k].v=tr[k*2].v+tr[k*2+1].v;
}
void passdown(int k)
{
tr[k*2].f+=tr[k].f;
tr[k*2+1].f+=tr[k].f;
tr[k*2].v+=tr[k].f*(tr[k*2].r-tr[k*2].l+1);
tr[k*2+1].v+=tr[k].f*(tr[k*2+1].r-tr[k*2+1].l+1);
tr[k].f=0;
}
long long int sum=0;
void query(int x, int y,int k)
{
if(x<=tr[k].l&&y>=tr[k].r)
{
sum+=tr[k].v;
return ;
}
if(tr[k].f) passdown(k);
int mid=(tr[k].l+tr[k].r)/2;
if(x<=mid) query(x,y,k*2);
if(y>mid) query(x,y,k*2+1);
}
void add(int x,int y,int w,int k)
{
if(x<=tr[k].l&&tr[k].r<=y)
{
tr[k].v+=w*(tr[k].r-tr[k].l+1);
tr[k].f+=w;
return ;
}
if(tr[k].f ) passdown(k);
int mid=(tr[k].r+tr[k].l)/2;
if(x<=mid) add(x,y,w,k*2);
if(y>mid) add(x,y,w,k*2+1);
tr[k].v=tr[k*2].v+tr[k*2+1].v;
}
int main()
{
cin>>n>>m;
creat(1,n,1);
char a;
int b,c,d;
for(int i=1;i<=m;i++)
{
sum=0;
cin>>a;
if(a=='Q')//query
{
scanf("%d %d",&b,&c);
if(b>c)
swap(b,c);
query(b,c,1);
cout<<sum<<endl;
}
else//add
{
scanf("%d %d %d",&b,&c,&d);
if(b>c)
swap(b,c);
add(b,c,d,1);
}
}
return 0;
}