线段树
以下部分文本摘自百度。
1.简介
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为 O ( log N ) O(\log N) O(logN) 。而未优化的空间复杂度为 2 N 2N 2N ,实际应用时一般还要开空间为 4 N 4N 4N 的数组以免越界,因此有时需要离散化让空间压缩。
2.定义
线段树是一种二叉搜索树,与区间树相似,它将一个区间划分成一些单元区间,每个单元区间对应线段树中的一个叶结点。
对于线段树中的每一个非叶子节点 [ a , b ] [a,b] [a,b] ,它的左儿子表示的区间为 [ a , ( a + b ) / 2 ] [a,(a+b)/2] [a,(a+b)/2] ,右儿子表示的区间为 [ ( a + b ) / 2 + 1 , b ] [(a+b)/2+1,b] [(a+b)/2+1,b] 。因此线段树是平衡二叉树,最后的子节点数目为 N N N ,即整个线段区间的长度。
使用线段树可以快速的查找某一个节点在若干条线段中出现的次数,时间复杂度为 O ( log N ) O(\log N) O(logN) 。而未优化的空间复杂度为 2 N 2N 2N ,实际应用时一般还要开空间为 4 N 4N 4N 的数组以免越界,因此有时需要离散化让空间压缩。
3.基本结构
线段树是建立在线段的基础上,每个结点都代表了一条线段 [ a , b ] [a,b] [a,b] 。长度为1的线段称为元线段。非元线段都有两个子结点,左结点代表的线段为 [ a , ( a + b ) / 2 ] [a,(a + b) / 2] [a,(a+b)/2],右结点代表的线段为 [ ( ( a + b ) / 2 ) + 1 , b ] [((a + b) / 2)+1,b] [((a+b)/2)+1,b] 。
长度范围为 [ 1 , L ] [1,L] [1,L] 的一棵线段树的深度为 log L + 1 \log L + 1 logL+1 。这个显然,而且存储一棵线段树的空间复杂度为 O ( L ) O(L) O(L) 。
线段树支持最基本的操作为插入和删除一条线段。下面以插入为例,详细叙述,删除类似。
将一条线段 [ a , b ] [a,b] [a,b] 插入到代表线段 [ l , r ] [l,r] [l,r] 的结点 p p p 中,如果 p p p 不是元线段,那么令 m i d = ( l + r ) / 2 mid=(l+r)/2 mid=(l+r)/2 。如果 b < m i d b<mid b<mid ,那么将线段 [ a , b ] [a,b] [a,b] 也插入到 p p p 的左儿子结点中,如果 a > m i d a>mid a>mid ,那么将线段 [ a , b ] [a,b] [a,b] 也插入到 p p p 的右儿子结点中。
插入(删除)操作的时间复杂度为 O ( log n ) O(\log n) O(logn) 。
4.实际应用
上面的都是些基本的线段树结构,但只有这些并不能做什么,就好比一个程序有输入没输出,根本没有任何用处。
最简单的应用就是记录线段是否被覆盖,随时查询当前被覆盖线段的总长度。那么此时可以在结点结构中加入一个变量 int count;
,代表当前结点代表的子树中被覆盖的线段长度和。这样就要在插入(删除)当中维护这个 count
值,于是当前的覆盖总值就是根节点的 count
值了。
另外也可以将 int count;
换成 bool cover;
支持查找一个结点或线段是否被覆盖。
实际上,通过在结点上记录不同的数据,线段树还可以完成很多不同的任务。例如,如果每次插入操作是在一条线段上每个位置均加 k
,而查询操作是计算一条线段上的总和,那么在结点上需要记录的值为 sum
。
这里会遇到一个问题:为了使所有sum值都保持正确,每一次插入操作可能要更新
O
(
N
)
O(N)
O(N) 个 sum
值,从而使时间复杂度退化为
O
(
N
)
O(N)
O(N) 。
解决方案是 Lazy 思想:对整个结点进行的操作,先在结点上做标记,而并非真正执行,直到根据查询操作的需要分成两部分。
根据 Lazy 思想,我们可以在不代表原线段的结点上增加一个值 toadd
,即为对这个结点,留待以后执行的插入操作 k
值的总和。对整个结点插入时,只更新 sum
和 toadd
值而不向下进行,这样时间复杂度可证明为
O
(
log
N
)
O(\log N)
O(logN) 。
对一个 toadd
值为
0
0
0 的结点整个进行查询时,直接返回存储在其中的 sum
值;而若对 toadd
不为
0
0
0 的一部分进行查询,则要更新其左右子结点的 sum
值,然后把 toadd
值传递下去,再对这个查询本身,左右子结点分别递归下去。时间复杂度也是
O
(
N
log
N
)
O(N \log N)
O(NlogN) 。
5.基本代码
P3372
(学习了皎月半洒花的代码)
#include <bits/stdc++.h>
#define int long long
#define us unsigned
using namespace std;
inline int read()
{
int ww = 0,ee = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
{
ee = -1;
}
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
ww = ww * 10 + ch - '0';
ch = getchar();
}
return ww * ee;
}
us int n,m;
const int N = 4000010;
us int a[N];
us int ans[N],tag[N];
inline int ls(int a) //计算左儿子,同a*2
{
return (a << 1);
}
inline int rs(int a) //计算右儿子,同a*2+1
{
return (a << 1 | 1);
}
int op,opx,opy,opz;
void push_up(int x) //记录 x 节点的值
{
ans[x] = ans[ls(x)] + ans[rs(x)];
}
void build(int root,int l,int r)
{
tag[root] = 0; //tag数组为懒标记
if (l == r) //叶(子)节点,记录节点值,返回
{
ans[root] = a[l];
return ;
}
//不是叶(子)节点,继续往后搜
int mid = (l+r) / 2;
build(ls(root),l,mid);
build(rs(root),mid+1,r);
push_up(root); //记录节点值
}
void f(int x,int l,int r,int k)
{
//记录当前节点所代表的区间
ans[x] = ans[x] + k * (r-l+1);
//记录节点值
tag[x] = tag[x] + k;
//懒标记
}
void push_down(int x,int l,int r)
{
int mid = (l+r) >> 1;
f(ls(x),l,mid,tag[x]);
f(rs(x),mid+1,r,tag[x]);
tag[x] = 0;
//每次更新两个子节点,不断向下传递,传递后清空
}
void update(int xl,int xr,int l,int r,int x,int k)
{
//xl,xp为要修改的区间
//x为当前节点
//l,r为x节点所存储的区间
if (xl <= l && r <= xr)
{
tag[x] += k;
ans[x] += k * (r-l+1);
return ;
}
push_down(x,l,r);
//向下传递
int mid = (l+r) / 2;
if (xl <= mid)
{
update(xl,xr,l,mid,ls(x),k);
}
if (xr > mid)
{
update(xl,xr,mid+1,r,rs(x),k);
}
push_up(x);
}
int query(int qx,int qy,int l,int r,int x)
{
int ret = 0;
if (qx <= l && r <= qy)
{
return ans[x];
}
int mid = (l+r) / 2;
push_down(x,l,r);
if (qx <= mid)
{
ret += query(qx,qy,l,mid,ls(x));
}
if (qy > mid)
{
ret += query(qx,qy,mid+1,r,rs(x));
}
return ret;
}
signed main()
{
n = read();
m = read();
for (int i = 1;i <= n;i++)
{
a[i] = read();
}
build(1,1,n);
//建立 根节点为1,区间为1-n的树
while (m--)
{
op = read();
if (op == 1)
{
opx = read();
opy = read();
opz = read();
update(opx,opy,1,n,1,opz);
}
else
{
opx = read();
opy = read();
printf("%lld\n",query(opx,opy,1,n,1));
}
}
return 0;
}
无注释版
#include <bits/stdc++.h>
#define int long long
#define us unsigned
using namespace std;
inline int read()
{
int ww = 0,ee = 1;
char ch = getchar();
while (ch < '0' || ch > '9')
{
if (ch == '-')
{
ee = -1;
}
ch = getchar();
}
while (ch >= '0' && ch <= '9')
{
ww = ww * 10 + ch - '0';
ch = getchar();
}
return ww * ee;
}
us int n,m;
const int N = 4000010;
us int a[N];
us int ans[N],tag[N];
inline int ls(int a)
{
return (a << 1);
}
inline int rs(int a)
{
return (a << 1 | 1);
}
int op,opx,opy,opz;
void push_up(int x)
{
ans[x] = ans[ls(x)] + ans[rs(x)];
}
void build(int root,int l,int r)
{
tag[root] = 0;
if (l == r)
{
ans[root] = a[l];
return ;
}
int mid = (l+r) / 2;
build(ls(root),l,mid);
build(rs(root),mid+1,r);
push_up(root);
}
void f(int x,int l,int r,int k)
{
ans[x] = ans[x] + k * (r-l+1);
tag[x] = tag[x] + k;
}
void push_down(int x,int l,int r)
{
int mid = (l+r) >> 1;
f(ls(x),l,mid,tag[x]);
f(rs(x),mid+1,r,tag[x]);
tag[x] = 0;
}
void update(int xl,int xr,int l,int r,int x,int k)
{
if (xl <= l && r <= xr)
{
tag[x] += k;
ans[x] += k * (r-l+1);
return ;
}
push_down(x,l,r);
int mid = (l+r) / 2;
if (xl <= mid)
{
update(xl,xr,l,mid,ls(x),k);
}
if (xr > mid)
{
update(xl,xr,mid+1,r,rs(x),k);
}
push_up(x);
}
int query(int qx,int qy,int l,int r,int x)
{
int ret = 0;
if (qx <= l && r <= qy)
{
return ans[x];
}
int mid = (l+r) / 2;
push_down(x,l,r);
if (qx <= mid)
{
ret += query(qx,qy,l,mid,ls(x));
}
if (qy > mid)
{
ret += query(qx,qy,mid+1,r,rs(x));
}
return ret;
}
signed main()
{
n = read();
m = read();
for (int i = 1;i <= n;i++)
{
a[i] = read();
}
build(1,1,n);
while (m--)
{
op = read();
if (op == 1)
{
opx = read();
opy = read();
opz = read();
update(opx,opy,1,n,1,opz);
}
else
{
opx = read();
opy = read();
printf("%lld\n",query(opx,opy,1,n,1));
}
}
return 0;
}