@[K ACMer]
题意:
支持两个操作:
- 求任意任意区间元素的和
- 对任意区间更新元素的值
分析:
就是对线段树进行区间更新的操作,用一个data数组来延迟下推,只有当查到这里的时候才计算.这样让区间更新的复杂度变为了
log(n)
其实只是对区间完全包含的情况就把区间暂存在这里不蔓延到儿子,查寻的时候来查即可.
也可以用经典的树状数组来做:
__树状数组除了最显然的是维护区间前n项和意外还有就是增加一个数x,其后面的值都增加x.即对后缀区间的增加__**这其实是一个前n项和的性质,对前n项和的更新,等价于对n后面的说有数都加上这个更新值.**
这里用两个树状数组来实现区间更新.充分体现了BIT是对整个后缀增加同一个值这个特性.并利用这个特性来构造出了区间更新.
线段树的Code:
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long ll;
ll data[300009], datb[300009];
int n, m;
void add(int a, int b, int x, int k, int l, int r) {
if (a <= l && b >= r) data[k] += x;
else if (l < b && a < r){
datb[k] += (min(r, b) - max(l, a)) * x;
add(a, b, x, k * 2 + 1, l, (l + r) / 2);
add(a, b, x, k * 2 + 2, (l + r) / 2, r);
}
}
ll sum(int a, int b, int k, int l, int r) {
if (a >= r || b <= l) return 0;
else if (a <= l && b >= r) return data[k] * (r - l) + datb[k];
else {
ll res = data[k] * (min(b, r) - max(l, a));
res += sum(a, b, k * 2 + 1, l, (l + r) / 2);
res += sum(a, b, k * 2 + 2, (l + r) / 2, r);
return res;
}
}
int main(void) {
while (~scanf("%d%d", &n, &m)) {
memset(data, 0, sizeof(data));
memset(datb, 0, sizeof(datb) );
for (int i = 0; i < n; i++) {
int x;
scanf("%d", &x);
add(i, i + 1, x, 0, 0, n);
}
for (int i = 0; i < m; i++) {
char c;
cin >> c;
if (c == 'Q') {
int l, r;
scanf("%d%d", &l, &r);
printf("%I64d\n", sum(l - 1, r, 0, 0, n));
} else {
int l, r, x;
scanf("%d%d%d", &l, &r, &x);
add(l - 1, r , x, 0, 0, n);
}
}
}
return 0;
}
BIT的Code:
#include <cstdio>
#include <iostream>
#include <cstring>
using namespace std;
typedef long long ll;
int n, m;
ll bita[100009], bitb[100009];
ll sum(ll* b, int i) {
ll ret = 0;
while ( i > 0) {
ret += b[i];
i -= i & -i;
}
return ret;
}
void add(ll* b, int k, int a) {
while (k <= n) {
b[k] += a;
k += k & -k;
}
}
int main(void) {
while (~scanf("%d%d", &n, &m)) {
memset(bitb, 0, sizeof(ll) * (n + 5));
memset(bita, 0, sizeof(ll) * (n + 5));
for (int i = 0; i < n; i++) {
int x;
scanf("%d", &x);
add(bita, i + 1, x);
}
for (int i = 0; i < m; i++) {
char c;
cin >> c;
if (c == 'Q') {
int l, r;
scanf("%d%d", &l, &r);
ll ans = 0;
ans += sum(bita, r) + sum(bitb, r) * r;
ans -= sum(bita, l - 1) + sum(bitb, l - 1) * (l - 1);
printf("%lld\n", ans);
} else {
int l, r, x;
scanf("%d%d%d", &l, &r, &x);
add(bita, l, -x * (l - 1));
add(bita, r + 1, x * r);
add(bitb, r + 1, -x);
add(bitb, l, x);
}
}
}
return 0;
}