什么是线段树
假设有编号从1到n的n个点,每个点都存了一些信息,用[L,R]表示下标从L到R的这些点。
线段树的用处就是,对编号连续的一些点进行修改或者统计操作,修改和统计的复杂度都是O(log2(n)).
线段树的原理,就是,将[1,n]分解成若干特定的子区间(数量不超过4*n)(因此空间也要开四倍),然后,将每个区间[L,R]都分解为少量特定的子区间,通过对这些少量子区间的修改或者统计,来实现快速对[L,R]的修改或者统计。
由此看出,用线段树统计的东西,必须是左右能够合并的。
(区间求和,区间最大值,区间最大子串和等等)
线段树的节点
ps:出于方便之后的ls和rs都是宏定义的左右儿子
#define ls now<<1
#define rs now<<1|1
struct node {
int sum;//这个节点要维护的东西,这里就是求和
int l, r;//这个节点维护的左右区间
}tree[N<<2];//4倍空间
图例: 下面是一个维护8个节点的简单线段树
每个节点代表的是区间的和。
线段树的实现十分简单,主要是看维护的方法(push_up),比如说你想要求一段区间的和,那么push_up就是要把父亲的大小等于两个儿子的大小之和。
void push_up(int now)
{
tree[now].sum = tree[ls].sum + tree[rs].sum;//左儿子加右儿子
}
我们可以通过修改push_up来让我们下面的操作维护其他东西
举个例子:你如果要找区间最大值,那就修改成左右儿子取最大。
tree[now].maxn=max(tree[ls].maxn,tree[rs].maxn);
相应的查询操作下面部分稍微修改即可
然后就是构建(build),更新(update),查询(query)3个基本操作,基本上都是复读机,属于是粘贴复制改一下细节就行。
ps:这里的更新操作是单点更新,要区间更新的可以去看下面。
如果叶子节点没有初始值(为0)的话,也要写build函数,因为我的写法是结构体表示区间,所以build还有初始化区间的作用,如果是用函数里的参数来表示的话,可以不写build。(个人习惯吧)
void build(int now, int l, int r)
{
tree[now].l = l; tree[now].r = r;//初始化维护的区间
if (l == r)
{
cin>>tree[now].sum;//到了叶子节点就输入要的数
return;
}
int mid = (l + r) >> 1;//线段树每次分一半
build(ls, l, mid); build(rs, mid + 1, r);//递归构建
push_up(now);//更新父亲节点
}
void update(int now, int k, int v)//要更新的节点k,要更新的数量v
{
int l = tree[now].l, r = tree[now].r;
if (l == r)
{
tree[now].sum = v;
return;
}
int mid = (l + r) >> 1;
if (k <= mid)update(ls, k, v);//在左边就更新左边
else update(rs, k, v);//否则更新右边
push_up(now);
}
int query(int now, int L, int R)//注意这里是要查询的区间L,R,不是now表示的区间
{
int l = tree[now].l, r = tree[now].r;
if (L <= l && r <= R)//满足区间
{
return tree[now].sum;
}
int ans = 0;
int mid = (l + r) >> 1;
if (L <= mid)ans+=query(ls, L, R);//找左边
if (mid < R)ans+=query(rs, L, R);//找右边,注意不是else,因为是区间,两边都有可能
return ans;
}
push_down(区间操作)
区间更新的时候我们如果像查询的时候一样找到就返回,那么下面的值就不会更新,如果我们直接每个数都修改到底的话,复杂度又会太大。所以我们还要在节点里加一个懒标记(tag)来表示我这个节点的下面部分都要修改这个量,只是我先不传递(push_down)。这样可以保证一次修改复杂度在 o ( log n ) o(\log n) o(logn)
void push_down(int now)
{
if (tree[now].tag)
{
tree[ls].tag += tree[now].tag;//下传标记
tree[rs].tag += tree[now].tag;//下传标记
int mid = (tree[now].l + tree[now].r) >> 1;
tree[ls].sum += (mid - tree[now].l + 1) * tree[now].tag;//全体都加
tree[rs].sum += (tree[now].rt - mid) * tree[now].tag;
tree[now].tag = 0;//下传完就情空
}
}
void update(int now, int L,int R, int v)//要更新的节点k,要加上的值v
{
int l = tree[now].l, r = tree[now].r;
if (L <= l && r <= R)//满足区间
{
tree[now].sum += v*(r-l+1);//区间加法 数组长度*要加的数
tree[now].tag += v;//打个标记,表示儿子以后都要加
return;
}
int mid = (l + r) >> 1;
push_down(now);//如果有标记就要下传
if (L <= mid)update(ls, L, R, v);//找左边
if (mid < R)update(rs, L, R, v);
push_up(now);
}
int query(int now, int L, int R)//注意这里是要查询的区间L,R,不是now表示的区间
{
int l = tree[now].l, r = tree[now].r;
if (L <= l && r <= R)//满足区间
{
return tree[now].sum;
}
int ans = 0;
push_down(now);//如果有标记就要下传
int mid = (l + r) >> 1;
if (L <= mid)ans += query(ls, L, R);//找左边
if (mid < R)ans += query(rs, L, R);//找右边,注意不是else,因为是区间,两边都有可能
return ans;
}
遇到的一些区间操作
只放了下传,合并,更新,其他都挺模板的。
先乘后加
struct node {
ll v, add, mul;
}tree[MAXN << 2];
void push_up(int now)
{
tree[now].v = (tree[ls].v + tree[rs].v) % mod;
}
void push_down(int now, int l, int r)
{
int mid = (l + r) >> 1;
tree[ls].v = (tree[ls].v * tree[now].mul + tree[now].add * (mid - l + 1)) % mod;
tree[rs].v = (tree[rs].v * tree[now].mul + tree[now].add * (r - mid)) % mod;
tree[ls].mul = (tree[ls].mul * tree[now].mul) % mod;
tree[rs].mul = (tree[rs].mul * tree[now].mul) % mod;
tree[ls].add = (tree[ls].add * tree[now].mul + tree[now].add) % mod;
tree[rs].add = (tree[rs].add * tree[now].mul + tree[now].add) % mod;
tree[now].add = 0; tree[now].mul = 1;
}
void update1(int now, int l, int r, int x, int y, ll k)//加法
{
if (x <= l && r <= y)
{
tree[now].add = (tree[now].add + k) % mod;
tree[now].v = (tree[now].v + k * (r - l + 1)) % mod;
return;
}
push_down(now, l, r);
int mid = (l + r) >> 1;
if (x <= mid)update1(now << 1, l, mid, x, y, k);
if (mid < y)update1(now << 1 | 1, mid + 1, r, x, y, k);
push_up(now);
}
void update2(int now, int l, int r, int x, int y, ll k)//乘法
{
if (x <= l && r <= y)
{
tree[now].mul = (tree[now].mul * k) % mod;
tree[now].add = (tree[now].add * k) % mod;
tree[now].v = (tree[now].v * k) % mod;
return;
}
push_down(now, l, r);
int mid = (l + r) >> 1;
if (x <= mid)update2(now << 1, l, mid, x, y, k);
if (mid < y)update2(now << 1 | 1, mid + 1, r, x, y, k);
push_up(now);
}
区间向下取根号取整
struct node {
ll mx, mn;//最大值最小值
int l, r;
ll sum, tag;
}tree[maxn << 2];
void push_up(int now)
{
tree[now].sum = tree[ls].sum + tree[rs].sum;
tree[now].mx = max(tree[ls].mx, tree[rs].mx);
tree[now].mn = min(tree[ls].mn, tree[rs].mn);
}
void push_down(int now)
{
if (tree[now].tag == -1)return;
tree[ls].tag = tree[rs].tag = tree[now].tag;
tree[ls].mx = tree[rs].mx = tree[ls].mn = tree[rs].mn = tree[now].tag;
int mid = (tree[now].l + tree[now].r) >> 1;
tree[ls].sum = (mid - tree[now].l + 1) * tree[now].tag;
tree[rs].sum = (tree[now].r - mid) * tree[now].tag;
tree[now].tag = -1;
}
void build(int now, int l, int r)
{
tree[now].mx = -1e9; tree[now].mn = 1e9;
tree[now].l = l; tree[now].r = r; tree[now].tag = -1;
if (l == r)
{
ll x;cin>>x;
tree[now].sum = tree[now].mn = tree[now].mx = x;
return;
}
int mid = (l + r) >> 1;
build(ls, l, mid);
build(rs, mid + 1, r);
push_up(now);
}
void update(int now, int x, int y)
{
ll a1 = sqrt(tree[now].mx);
ll a2 = sqrt(tree[now].mn);
if (a1 == a2 && x <= tree[now].l && tree[now].r <= y)
{
tree[now].tag = a1;
tree[now].mx = tree[now].mn = a1;
tree[now].sum = (tree[now].r - tree[now].l + 1) * (a1);
return;
}
int mid = (tree[now].l + tree[now].r) >> 1;
push_down(now);
if (x <= mid)update(ls, x, y);
if (mid < y)update(rs, x, y);
push_up(now);
}
ll query(int now, int L, int R)
{
int l = tree[now].l, r = tree[now].r;
if (L <= l && r <= R)
{
return tree[now].sum;
}
ll ans = 0;
int mid = (l + r) >> 1;
if (L <= mid)ans += query(ls, L, R);
if (mid < R)ans += query(rs, L, R);
return ans;
}
等差数列
给定 l , r , k , d l,r,k,d l,r,k,d,要求给区间 [ l , r ] [l,r] [l,r]的每一个数加上 a i = a i + ( k + ( i − l ) ∗ d ) a_i=a_i+(k+(i-l)*d) ai=ai+(k+(i−l)∗d)的等差数列。
我们就定义两个标记,第一个表示区间最左边的数的公差d,第二个表示整个区间要加的值。
下传操作:区间加没问题,主要是公差怎么下传。
来看一个例子:比如
[
1
,
5
]
[1,5]
[1,5]tag1为1的话,相当于要加上
[
1
,
5
]
[1,5]
[1,5]加上1,2,3,4,5;
那分到下面的区间
[
1
,
3
]
[1,3]
[1,3]和
[
4
,
5
]
[4,5]
[4,5]我们发现左半部分是可以不用处理的,直接标记下传加上总和值就行。
而右半段的4,5可以看成1+3,2+3。这样我们下传的公差就还是1,只是右区间整个区间还要多加上前面左区间最右边的数。
struct node{
int l,r;
ll tag1,tag2;
ll sum;
}tree[N<<2];
void push_up(int now)
{
tree[now].sum=tree[ls].sum+tree[rs].sum;
}
void push_down(int now)
{
if(tree[now].tag1||tree[now].tag2)
{
int l=tree[now].l,r=tree[now].r;
int mid=(l+r)>>1;
ll llen=mid-l+1,rlen=r-mid;
tree[ls].tag1+=tree[now].tag1;
tree[rs].tag1+=tree[now].tag1;
tree[ls].sum+=tree[now].tag1*(llen+1)*(llen)/2;
tree[rs].sum+=tree[now].tag1*(rlen+1)*(rlen)/2;
tree[rs].tag2+=tree[now].tag1*llen;
tree[rs].sum+=tree[now].tag1*llen*rlen;
tree[ls].tag2+=tree[now].tag2;
tree[rs].tag2+=tree[now].tag2;
tree[ls].sum+=tree[now].tag2*llen;
tree[rs].sum+=tree[now].tag2*rlen;
tree[now].tag2=tree[now].tag1=0;
}
}
void update(int now,int x,int y,ll k,ll d)
{
if(tree[now].l>y||tree[now].r<x)return;
if(x<=tree[now].l&&tree[now].r<=y)
{
tree[now].tag1+=d;
tree[now].tag2+=d*(tree[now].l-x)+k;
ll len=tree[now].r-tree[now].l+1;
tree[now].sum+=(len+1)*len/2*d;
tree[now].sum+=len*(d*(tree[now].l-x)+k);
return;
}
push_down(now);
update(ls,x,y,k,d);
update(rs,x,y,k,d);
push_up(now);
}