一、题目大意
我们有N个数字,多次区间进行修改和区间求和,每次求和时输出结果。
二、解题思路
最适合处理区间操作的数据结构是线段树和树状数组,它们的复杂性都是O(n*logn),引用《挑战程序设计》中的一段话,“比起线段树,Bit实现起来更方便,速度也更快”。所以更推荐树状数组。
但是树状数组适合于单点修改和区间查询,不适合对区间内容进行修改。所以我们要把区间修改的问题转化成单点修改的问题。
假设 [1 , i] 的区间和为 S(i),我们来考虑下,当[L , R]区间每个元素都加大x,对于S(i)的影响
if i < L
S(i) = S(i)
else if i >= L && i<=R
S(i) = S(i) + (i - L + 1) * x
= S(i) + i * x - L * x + 1 * x
= S(i) + i * x + (L * x - 1 * x)
= S(i) + i * x + (L - 1) * x
else if i > R
S(i) = S(i) + (R - L + 1) * x
= S(i) + R * x - L * x + 1 * x
= S(i) + R * x + (L * x - 1 * x)
= S(i) + R * x + (L - 1) * x
那么我们可以定义两个数组 bit0 和 bit1,当 [L ,R] 区间进行区间变化时,对于小于L的i不用处理,当L <= i时,我们发现这个值与 x 有关,所以我们需要记录下这个 x,可以将所有的 x 放在 bit1数组中,当[L , R]加x时,bit1[L]+=x,同时我们再对bit0[L]-=(L-1)*x,当i∈[L , R],不难看出sum(bit0,i) + sum(bit1,i)*i = S(i) + (L - 1) * x + i * x,即可计算 [1 , i]的和。
当 i > R时,其实这段区间内的变化对 S(i)而言只是加了一个 R * x + (L - 1)*x。因为我们是对bit0是从1开始区间求和,同时之前对 bit0[L]-=(L - 1)*x,所以当我们只需要对 R + 1的位置,修改bit1设置bit[R + 1]+=R*x,然后因为 i > L 时,已经没有 i * x了,所以 bit1[R+1]-=x来抵消掉前面bit[L]的x
这样对于i∈[R+1 , ∞],就可以通过bit1和bit0的前i项求和,达到 S(i)=S(i) + R * x - (L - 1)*x的效果。
这样的话,更新的操作就可以简化了,[L , R]区间加x分为以下4步
1)bit0[L]-=(L-1)*(x)
2)bit0[R + 1]+=R*x
3)bit1[L]+=x
4)bit1[R+1]-=x
同时这四步骤操作的bit需要经常计算前i项和,所以我们再操作某一个节点后,对于这四步骤,每一次都一定要及时的更新父亲节点( i = i + (i & (-i)直到 n 为止,每一个 i 都更新)
查询的时候就是查询出
1、bit0数组前 R 项的和 + bit1数组前R项的和 * R
2、bit0数组前 L - 1 项的和 + bit1数组前L-1项的和 * (L-1)
然后初始化的时候,因为不涉及到区间操作,那么就可以只更新bit0的第i位为输入的第i个数字即可(记得更新父亲节点),
bit我习惯初始化时候让n为2的次幂,这样我个人觉得规范!
三、代码
#include <iostream>
using namespace std;
typedef long long ll;
int num[100007], n, n_;
ll bit0[131080], bit1[131080];
void input()
{
for (int i = 1; i <= n_; i++)
{
scanf("%d", &num[i]);
}
}
void init()
{
n = 1;
while (n < n_)
{
n = n * 2;
}
for (int i = 0; i <= n; i++)
{
bit0[i] = 0LL;
bit1[i] = 0LL;
}
}
void updateBit1(int r, ll x)
{
if (r <= 0)
{
return;
}
for (int i = r; i <= n; i = i + (i & (-i)))
{
bit1[i] = bit1[i] + x;
}
}
void updateBit0(int r, ll x)
{
if (r <= 0)
{
return;
}
for (int i = r; i <= n; i = i + (i & (-i)))
{
bit0[i] = bit0[i] + x;
}
}
ll queryBit0(int r)
{
ll sum = 0LL;
for (int i = r; i > 0; i = i - (i & (-i)))
{
sum = sum + bit0[i];
}
return sum;
}
ll queryBit1(int r)
{
ll sum = 0LL;
for (int i = r; i > 0; i = i - (i & (-i)))
{
sum = sum + bit1[i];
}
return sum;
}
void push()
{
for (int i = 1; i <= n_; i++)
{
ll llNumI = ((ll)num[i]);
updateBit0(i, llNumI);
}
}
int main()
{
int l, r, v, q;
char c;
while (~scanf("%d%d", &n_, &q))
{
input();
init();
push();
while (q--)
{
scanf("\n%c", &c);
if (c == 'Q')
{
scanf("%d%d", &l, &r);
ll allAmt = queryBit0(r);
ll allAdd = queryBit1(r) * ((ll)r);
ll leftAmt = queryBit0(l - 1);
ll leftAdd = queryBit1(l - 1) * ((ll)(l - 1));
ll result = allAmt + allAdd - leftAmt - leftAdd;
printf("%lld\n", result);
}
else if (c == 'C')
{
scanf("%d%d%d", &l, &r, &v);
ll llV = ((ll)v);
updateBit1(l, llV);
updateBit1(r + 1, (-1LL) * llV);
updateBit0(l, (-1LL) * ((ll)(l - 1)) * llV);
updateBit0(r + 1, ((ll)r) * llV);
}
}
}
return 0;
}