虽然看着挺长,但是其中有很多注释,同时也可作为洛谷P3372的题解。
从第10行看即可。
/// @file 线段树模板
/// 默认下标从0开始, @see StartAt
// 如果很闲,可以包装成模板类打发时间
#include <iostream>
#if __cplusplus < 201103L
typedef unsigned long uint32_t;
#endif
typedef long long ll;
const int StartAt = 1;
const int N = 100001;
int *val; // values
struct /* 不知道取啥名 */
{
int l,r;
ll sum,add;
}a[N*4];
/** @brief 构建线段树
* @note 使用方法:build(1,n), 其中n表示数组val的长度
*/
void build(const int l, const int r, const uint32_t i=StartAt) // i: 当前点编号,下同
{
a[i].l = l,
a[i].r = r;
if (l==r){
a[i].sum = val[l];
return;
}
int mid = (l + r) >> 1;
build(l, mid, i<<1);
build(mid+1, r, i<<1|1);
a[i].sum = a[i<<1].sum + a[i<<1|1].sum;
}
/**
* @warning 在修改和查询时,只要进入子节点,就一定要`down`
*/
inline void down(const uint32_t i)
{
if (a[i].add){
ll &x = a[i].add;
a[i<<1].add += x;
a[i<<1|1].add += x;
a[i<<1].sum += (a[i<<1].r -a[i<<1].l +1) * x;
a[i<<1|1].sum += (a[i<<1|1].r -a[i<<1|1].l +1) * x;
a[i].add = 0; // 一定要清零
}
}
inline void up(const uint32_t i)
{
a[i].sum = 0;
// 我也不知道两个if有什么用
if (a[i<<1].l)
a[i].sum += a[i<<1].sum;
if (a[i<<1|1].l)
a[i].sum += a[i<<1|1].sum;
}
/** @brief 单点修改
* @param index 修改的位置
* @param v 变化量
*/
void update(const int index, const ll v, const uint32_t i=StartAt)
{
if (a[i].l == a[i].r){
a[i].sum += v;
return;
}
down(i);
int mid = (a[i].l + a[i].r) >> 1;
if (index<=mid)
update(index, v, i<<1);
else
update(index, v, i<<1|1);
up(i);
}
/** @brief 区间修改
* @param l 左端点(起始位置)
* @param r 右端点(终止位置)
* @param add 变化量
* @warning 调用此函数时,“可选参数”i不能省略;
* 若省略,会匹配三参的单点修改update。
*/
void updata(const int l, const int r, const ll add, const uint32_t i=StartAt)
{
if (a[i].l==l && a[i].r==r){
a[i].sum += (a[i].r - a[i].l +1) * add;
a[i].add += add;
return;
}
down(i);
int mid = (a[i].l + a[i].r) >> 1;
if (mid >= r)
updata(l, r, add, i<<1);
else if (mid < l)
updata(l, r, add, i<<1|1);
else
updata(l, mid, add, i<<1),
updata(mid+1, r, add, i<<1|1);
up(i);
}
/** @brief 区间查询
* @param l 所查询区间的左端点
* @param r 所查询区间的右端点
*/
ll query(const int l, const int r, const uint32_t i=StartAt)
{
if (a[i].l==l && a[i].r==r)
return a[i].sum;
down(i);
int mid = (a[i].l + a[i].r) >> 1;
if (mid >= r)
return query(l, r, i<<1);
else if (mid < l)
return query(l, r, i<<1|1);
return query(l, mid, i<<1)
+query(mid+1, r, i<<1|1);
}
int main()
{
using namespace std;
int n,m, c,x,y;
ll k;
cin >> n >> m;
val = new int[n+2];
for (int i=1; i<=n; ++i)
cin >> val[i];
build(1,n);
delete val;
while (m--){
cin >> c >> x >> y;
if (c==1){
cin >> k;
updata(x, y, k, 1);
}
else{
cout << query(x, y) << endl;
}
}
}