线段树是基于分治思想的二叉树,用来维护区间信息(区间和、区间最值、区间gcd等),可以在logn的时间内执行区间修改和区间查询
线段树中每个叶子节点存储元素本身,非叶子节点存储区间内元素的统计值
递归建树
父亲节点编号为p
左孩子编号为2p,右孩子编号为2*p+1
模板
void build (int p,int l,int r)
{
tr[p].l = l , tr[p].r = r , tr[p].sum = a[l] , tr[p].add = 0 ;
if (l == r)
return ;
int m = l + r >> 1 ;
build (lc,l,m) ;
build (rc,m+1,r) ;
pushup(p) ;//tr[p].sum = tr[lc].sum + tr[rc].sum ;
}
点修改
从根节点进入,递归找到叶子节点[x,x] ,把该节点的值增加k。然后从下往上更新其祖先节点的统计值
模板
void update (int p,int x,int k)
{
if (tr[p].l == x && tr[p].r == x)
{
tr[p].sum += k ;
return ;
}
int m = tr[p].l + tr[p].r >> 1 ;
if (x <= m) update (lc,x,k) ;
if (x > m) update (rc,x,k) ;
pushup (p) ;//tr[p].sum = tr[lc].sum + tr[rc].sum ;
}
区间查询
拆分与拼凑的思想。
例如,查询区间[4,9]可以拆分成[4,5],[6,8]和[9,9] ,通过合并这三个区间的答案求得查询答案
步骤
从根节点进入,递归执行一下过程
-
若查询区间[x,y] 完全覆盖当前节点区间,则立即回溯,并返回该节点的sum值
-
若左节点与[x,y]有重叠,则递归访问左子树
-
若右节点与[x,y]有重叠,则递归访问右子树
模板
LL query (int p,int x,int y)
{
if (x <= tr[p].l && tr[p].r <= y)
return tr[p].sum ;
LL sum = 0 ;
int m = tr[p].l + tr[p].r >> 1 ;
pushdown(p) ;//下一个区间模板有该函数
if (x <= m) sum += query(lc,x,y) ;
if (y > m) sum += query(rc,x,y) ;
return sum ;
}
区间修改
懒惰修改
当[x,y]完全覆盖该节点[a,b]时,先修改该区间的sum值,再打上一个“懒标记”,然后立即返回。等下次需要时,在下传“懒标记”。这样,可以把每次修改和查询的时间都控制在O(logn)
模板
void pushup (int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum ;
}
void pushdown (int p)
{
if (tr[p].add)
{
tr[lc].sum += (tr[lc].r - tr[lc].l + 1)*tr[p].add ;
tr[rc].sum += (tr[rc].r - tr[rc].l + 1)*tr[p].add ;
tr[lc].add += tr[p].add ;
tr[rc].add += tr[p].add ;
tr[p].add = 0;
}
}
void update (int p,int x,int y,LL k)
{
if (x <= tr[p].l&&tr[p].r <= y)
{
tr[p].sum += (tr[p].r - tr[p].l + 1) * k ;
tr[p].add += k ;
return ;
}
pushdown (p) ;
int m = tr[p].l + tr[p].r >> 1 ;
if (x <= m) update (lc,x,y,k) ;
if (y > m) update (rc,x,y,k) ;
pushup (p) ;
}
模板题
P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
AC代码
#include <iostream>
using namespace std ;
#define lc p<<1
#define rc p<<1|1
typedef long long LL ;
const int N = 1e5+10 ;
int n , m ;
LL a[N] ;
struct node {
int l , r ;
LL sum , add ;
};
node tr[4*N] ;
void pushup (int p)
{
tr[p].sum = tr[lc].sum + tr[rc].sum ;
}
void build (int p,int l,int r)
{
tr[p].l = l , tr[p].r = r , tr[p].sum = a[l] , tr[p].add = 0 ;
if (l == r)
return ;
int m = l + r >> 1 ;
build (lc,l,m) ;
build (rc,m+1,r) ;
pushup(p) ;
}
void pushdown (int p)
{
if (tr[p].add)
{
tr[lc].sum += (tr[lc].r - tr[lc].l + 1)*tr[p].add ;
tr[rc].sum += (tr[rc].r - tr[rc].l + 1)*tr[p].add ;
tr[lc].add += tr[p].add ;
tr[rc].add += tr[p].add ;
tr[p].add = 0 ;
}
}
void update (int p,int x,int y,LL k)
{
if (x <= tr[p].l&&tr[p].r <= y)
{
tr[p].sum += (tr[p].r - tr[p].l + 1) * k ;
tr[p].add += k ;
return ;
}
pushdown (p) ;
int m = tr[p].l + tr[p].r >> 1 ;
if (x <= m) update (lc,x,y,k) ;
if (y > m) update (rc,x,y,k) ;
pushup (p) ;
}
LL query (int p,int x,int y)
{
if (x <= tr[p].l && tr[p].r <= y)
return tr[p].sum ;
LL sum = 0 ;
int m = tr[p].l + tr[p].r >> 1 ;
pushdown(p) ;
if (x <= m) sum += query(lc,x,y) ;
if (y > m) sum += query(rc,x,y) ;
return sum ;
}
int main ()
{
scanf ("%d%d",&n,&m) ;
for (int i = 1;i <= n;i++)
scanf ("%lld",&a[i]) ;
build (1,1,n) ;
while (m--)
{
int c ;
scanf ("%d",&c) ;
if (c == 1)
{
int x , y ;
LL k ;
scanf ("%d%d%lld",&x,&y,&k) ;
update (1,x,y,k) ;
}
else
{
int x , y ;
scanf ("%d%d",&x,&y) ;
LL res = query (1,x,y) ;
cout << res << endl ;
}
}
return 0 ;
}