线段树 区间加k
昨天大佬将了下线段树,当时觉得不难,但自己动起手来就漏洞百出。这里,就来理理线段树的思想。
本文围绕下面代码展开讲解:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cmath>
using namespace std;
#define ll long long
ll num[500005];
struct tre
{
ll l, r, v,color;
ll len;
}s[500005];
void find_work(ll a)
{
if (s[a].color == 0)return;
s[2 * a].v += s[a].color * s[2*a].len;
s[2 * a + 1].v += s[a].color * s[2 * a + 1].len;
s[2 * a].color += s[a].color;
s[2 * a + 1].color += s[a].color;
s[a].color = 0;
}
void build_tree(ll a,ll l,ll r)
{
s[a].l = l; s[a].r = r; s[a].len = r - l + 1;
if (s[a].l == s[a].r) { s[a].v = num[l]; return; }
ll mid = (s[a].l + s[a].r) / 2;
build_tree(2 * a, l, mid);
build_tree(2 * a + 1, mid + 1, r);
s[a].v = s[2 * a].v + s[2 * a + 1].v;
}
ll find_tree(ll a, ll l, ll r)
{
if (s[a].l >= l && s[a].r <= r)return s[a].v;
ll mid = (s[a].l + s[a].r) / 2; find_work(a);
ll ans = 0;
if (l <= mid)ans += find_tree(2*a, l, r);
if (r > mid)ans += find_tree(2 * a + 1, l, r);
return ans;
}
void add_tree(ll a, ll l, ll r, ll k)
{
if (s[a].l >= l && s[a].r <= r)
{
s[a].v += k * s[a].len;
s[a].color += k;
return;
}
ll mid = (s[a].l + s[a].r) / 2; find_work(a);
if (l<=mid) { add_tree(2 * a,l, r,k); }
if (r>mid) { add_tree(2 * a+1, l, r,k); }
s[a].v = s[2 * a].v + s[2 * a + 1].v;
}
int main()
{
ll n, m;
ll op, ans, kase, k;
cin >> n >> m;
for (int i = 1; i <= n; i++)
{
cin >> num[i];
}
build_tree(1, 1, n);
while (m--)
{
cin >> op >> ans >> kase;
if (op == 1)
{
cin >> k;
add_tree(1,ans,kase, k);
}
else
{
cout << find_tree(1, ans, kase) << endl;
}
}
return 0;
}
1.1 一步一步来,我们先聊聊建树的思想。
引用一下其他人对线段树的概念理解:
第一部 概念引入
线段树是一种二叉树,也就是对于一个线段,我们会用一个二叉树来表示。比如说一个长度为4的线段,我们可以表示成这样:
这是什么意思呢? 如果你要表示线段的和,那么最上面的根节点的权值表示的是这个线段1-4的和。根的两个儿子分别表示这个线段中1-2的和,与2-3的和。以此类推。
首先应该明确,一个树区间里应该有些什么东西,换种说法就是:我们需要用什么来表示任意一个区间?
那么这样问题就清晰了,
1.1 一个区间肯定有开始的位置:我们把开始设为 l
1.2 一个区间肯定有终止的位置:我们把结尾设为 r
2.1 一个区间的数量: 我们用 len 表示长度
3.1 一个区间的所有元素和:用 v 表示
4.1 既然有这么多个区间,那么我们需要给这些区间标个号,排个序:用 a 表示区间号码。
下面单独看建树的代码:
void build_tree(ll a,ll l,ll r)
{
s[a].l = l; s[a].r = r; s[a].len = r - l + 1;
if (s[a].l == s[a].r) { s[a].v = num[l]; return; }
ll mid = (s[a].l + s[a].r) / 2;
build_tree(2 * a, l, mid);
build_tree(2 * a + 1, mid + 1, r);
s[a].v = s[2 * a].v + s[2 * a + 1].v;
}
一进入建树,我们就可以确定左右端点,区间的长度,所以可以直接赋值:
s[a].l = l; s[a].r = r; s[a].len = r - l + 1;
一个区间的元素和等于两个小区间的元素和相加,也就是1=0.5+0.5。
而如果这个区间的元素只有一个的话就直接等于该元素的值。
所以:
if (s[a].l == s[a].r) { s[a].v = num[l]; return; }
建树是一种由大到小再由小到大的过程。可以理解成先微分再积分。
微分:
build_tree(2 * a, l, mid);
build_tree(2 * a + 1, mid + 1, r);
积分:
s[a].v = s[2 * a].v + s[2 * a + 1].v;
2.1 建完树就要知道如何对树进行裁剪。
void add_tree(ll a, ll l, ll r, ll k)
{
if (s[a].l >= l && s[a].r <= r)
{
s[a].v += k * s[a].len;
s[a].color += k;
return;
}
ll mid = (s[a].l + s[a].r) / 2; find_work(a);
if (l<=mid) { add_tree(2 * a,l, r,k); }
if (r>mid) { add_tree(2 * a+1, l, r,k); }
s[a].v = s[2 * a].v + s[2 * a + 1].v;
}
下面对这段代码做出解释:
if (s[a].l >= l && s[a].r <= r)
{
s[a].v += k * s[a].len;//对该段区间的值进行修改
s[a].color += k;//标记这段区间表示已经修改
return;
}
ll mid = (s[a].l + s[a].r) / 2; find_work(a);
if (l<=mid) { add_tree(2 * a,l, r,k); }
if (r>mid) { add_tree(2 * a+1, l, r,k); }
s[a].v = s[2 * a].v + s[2 * a + 1].v;
这里find_work函数的作用就是传递。
void find_work(ll a)
{
if (s[a].color == 0)return;//如果没有标记就返回
s[2 * a].v += s[a].color * s[2*a].len;//加数
s[2 * a + 1].v += s[a].color * s[2 * a + 1].len;//加数
s[2 * a].color += s[a].color;//标记传递
s[2 * a + 1].color += s[a].color;//标记传递
s[a].color = 0;//取消原标记
}
阶段一:
阶段二:
最后要聊聊对区间的和的输出:
ll find_tree(ll a, ll l, ll r)
{
if (s[a].l >= l && s[a].r <= r)return s[a].v;
ll mid = (s[a].l + s[a].r) / 2; find_work(a);
ll ans = 0;
if (l <= mid)ans += find_tree(2*a, l, r);
if (r > mid)ans += find_tree(2 * a + 1, l, r);
return ans;
}
关于这段代码,如果看懂了上面我写的自然就不是问题(虽然我将的很水,但是相信你们会懂的)
这里很关键的点在于这个find_work(a),大家要细评!
最后要聊一点易错的东西:线段树主要是利用不断的二分,递归来完成的,所以在一些细节上需要十分注意,比如 l<=mid,r>mid。这个等于号就很容易遗漏,而遗漏的后果不是无限递归就是答案错误。