初学线段树,刚刚学会了一个模板,正好分享一下,带有注释。
上代码:
#include<iostream>
#include<stdio.h>
#include<cstring>
#include<cstdlib>
#include<cmath>
#include<vector>
#include<algorithm>
#include<stack>
#include<queue>
#include<deque>
#include <iomanip>
#include<sstream>
#include<numeric>
#include<map>
#include<limits.h>
#include<unordered_map>
#include<set>
#define int long long
#define MAX 100010
#define inf 0x3f3f3f3f
#define _for(i,a,b) for(int i=a;i<(b);i++)
using namespace std;
typedef pair<int, int> PII;
int n, m;
int counts;
int dx[] = { 0,1,0,-1};
int dy[] = { 1,0,-1,0 };
int arr[MAX];
int w[MAX];
struct node {
int l, r;
int sum, add;
}tr[MAX*4+7];
void pushup(int p) {
tr[p].sum = tr[2 * p].sum + tr[2 * p + 1].sum;
}
void pushdown(int p) {
if (tr[p].add) {
tr[2 * p].sum += (tr[2 * p].r - tr[2 * p].l + 1) * tr[p].add;//整个区间的数都增加了add,那么我们只需要将原来的和加上儿子们的增量就行了。
tr[2 * p + 1].sum += (tr[2 * p + 1].r - tr[2 * p + 1].l + 1) * tr[p].add;//同理
tr[2 * p].add += tr[p].add;//继承这个区间的加数
tr[2 * p + 1].add += tr[p].add;//同理
tr[p].add = 0;//既然已经继承到了下面,也就没有增量了。
}
}//用于区间修改的操作
void build(int p,int l, int r) {
tr[p] = { l,r,w[l],0 };//初始化
if (l == r) {
tr[p].sum = w[l];//这个w数组其实是用来输入数据的,也就是做题目的时候需要用到的输入数组。
return;
}//如果说这个时候的线段树区间是1个数,我们就返回
int mid = l + r >> 1;//不是,我们就开始平分区间
build(2 * p, l, mid);//在左边的区间开始构建
build(2 * p + 1, mid + 1, r);//在右边的区间开始构建,形似二叉树
pushup(p);//在下面递归之后,在向上回溯的时候需要进行区间和的合并,所以这里是对于区间和的记录。
}//构建线段树
void update(int p, int x, int y, int k) {
if (x <= tr[p].l && y>= tr[p].r) {//整个区间完全涵盖了我们想要查找的区间,我们就直接返回这里的和就行。
tr[p].sum += (tr[p].r - tr[p].l + 1) * k;//和上面的pushdown解释一样。
tr[p].add += k;//增量
return;
}
int mid = tr[p].l + tr[p].r >> 1;//如果没有完全涵盖还是需要分裂区间
pushdown(p);//向下更新区间和
if (x <= mid) update(2 * p, x, y, k);//查找区间的左边界小于中间的值,分裂到左区间
if (y > mid) update(2 * p + 1, x, y, k);//同理,分裂到右区间
pushup(p);//之后一块向上更新区间和
}
int quary(int p, int x, int y) {
if (x <= tr[p].l && y >= tr[p].r)
return tr[p].sum;
int mid = tr[p].r + tr[p].l >> 1;
pushdown(p);
int sum = 0;
if (x <= mid) sum += quary(2 * p, x, y);
if (y > mid) sum += quary(2 * p + 1, x, y);
return sum;
}
signed main() {
ios::sync_with_stdio(false);
cin.tie(NULL); cout.tie(NULL);
cin >> n >> m;
_for(i, 1, n + 1)
cin >> w[i];
build(1, 1, n);
while (m--) {
int bian;
cin >> bian;
if (bian == 2) {
int x, y;
cin >> x >> y;
cout << quary(1, x, y) << endl;
}
else if (bian == 1) {
int x, y, k;
cin >> x >> y >> k;
update(1, x, y, k);
}
}
return 0;
}