题目:给一个数组A1,A2,…An,两个操作:对[a,b]区间内的每个数加C,查询[a,b]区间的和。
也就是实现能够高效区间更新和求和的线段树和树状数组。
线段树的每个结点维护两个值,一个是对区间内所有数都加上的值t,这样在区间更新时不必再向下更新。另一个是除t外区间内所有数的和s。若结点k的对应区间为[l,r),则区间内所有数的和为s+t*(r-l)。
能够实现高效区间更新的树状数组则比较有意思,树状数组对单个值的更新为O(logn),求区间和为O(logn)。刚看到这种区间更新的解法时还感叹了一番…
s(i)为未更新前A1+A2+…+Ai
更新操作为[l,r]区间每个值加x
s’(i)为更新后的A1+A2+…+Ai
可以推出
s′(i)=⎧⎩⎨s(i)s(i)+x∗(i−l+1)=x∗i+s(i)+x∗(−l+1)s(i)+x(r−l+1)=s(i)+x∗r+x∗(−l+1)i<ll≤i≤rr<i
用一个树状数组bit1来维护i前面的系数,用另一个树状数组bit0来维护没和i相乘的部分。区间更新操作只需在端点附近进行操作。更新后的s’(i)=sum(bit1,i)*i+sum(bit0,i)
// segment tree
#include <iostream>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
using namespace std;
#define MAX_N 100005
typedef long long ll;
const int ST_SIZE = (1 << 18) - 1;
int N, Q;
int A[MAX_N];
ll s[ST_SIZE], t[ST_SIZE];
void init(int k, int l, int r) {
if (l == r - 1) {
s[k] = A[l];
} else {
int lc = 2 * k + 1, rc = 2 * k + 2;
int m = (l + r) / 2;
init(lc, l, m);
init(rc, m, r);
s[k] = s[lc] + s[rc];
}
t[k] = 0;
}
// [a, b) [l, r)
void update(int a, int b, int c, int k, int l, int r) {
if (b <= l || r <= a) return;
if (a <= l && r <= b) {
t[k] += c;
} else {
int lc = 2 * k + 1, rc = 2 * k + 2;
int m = (l + r) / 2;
update(a, b, c, lc, l, m);
update(a, b, c, rc, m, r);
s[k] += c * (min(b, r) - max(a, l));
}
}
ll query(int a, int b, int k, int l, int r) {
ll res = 0;
if (b <= l || r <= a) return res;
if (a <= l && r <= b) {
res += s[k];
} else {
int lc = 2 * k + 1, rc = 2 * k + 2;
int m = (l + r) / 2;
res += query(a, b, lc, l, m);
res += query(a, b, rc, m, r);
}
res += t[k] * (min(b, r) - max(a, l));
return res;
}
void print_tree(int k, int l, int r) {
printf("%d %d %d %lld %lld\n", k, l, r, s[k], t[k]);
if (l == r - 1) return;
int lc = 2 * k + 1, rc = 2 * k + 2;
int m = (l + r) / 2;
print_tree(lc, l, m);
print_tree(rc, m, r);
}
int main() {
scanf("%d %d", &N, &Q);
for (int i = 1; i <= N; i++) scanf("%d", A + i);
init(0, 1, N + 1);
for (int i = 0; i < Q; i++) {
char cmd[2];
scanf("%s", cmd);
if (cmd[0] == 'C') {
int a, b, c; scanf("%d %d %d", &a, &b, &c);
update(a, b + 1, c, 0, 1, N + 1);
} else {
int a, b; scanf("%d %d", &a, &b);
// print_tree(0, 1, N + 1);
ll res = query(a, b + 1, 0, 1, N + 1);
printf("%lld\n", res);
}
}
return 0;
}
// binary indexed tree
#include <iostream>
#include <vector>
#include <string>
#include <stack>
#include <queue>
#include <set>
#include <map>
#include <algorithm>
#include <functional>
#include <cstdio>
#include <cstdlib>
#include <cmath>
#include <cstring>
using namespace std;
#define MAX_N 100005
typedef long long ll;
int N, Q;
ll bit1[MAX_N], bit0[MAX_N];
void add(ll* bit, int i, int x) {
while (i <= N) {
bit[i] += x;
i += i & -i;
}
}
ll sum(ll* bit, int i) {
ll res = 0;
while (i > 0) {
res += bit[i];
i -= i & -i;
}
return res;
}
int main() {
scanf("%d %d", &N, &Q);
for (int i = 1; i <= N; i++) {
bit0[i] = 0;
bit1[i] = 0;
}
for (int i = 1; i <= N; i++) {
int a; scanf("%d", &a);
add(bit0, i, a);
}
for (int i = 0; i < Q; i++) {
char cmd[2];
scanf("%s", cmd);
if (cmd[0] == 'C') {
int a, b, c; scanf("%d %d %d", &a, &b, &c);
add(bit0, a, c * (-a + 1));
add(bit1, a, c);
add(bit0, b + 1, c * b);
add(bit1, b + 1, -c);
} else {
int a, b; scanf("%d %d", &a, &b);
ll s0 = sum(bit1, b) * b + sum(bit0, b);
ll s1 = sum(bit1, a - 1) * (a - 1) + sum(bit0, a - 1);
ll res = s0 - s1;
printf("%lld\n", res);
}
}
return 0;
}