线段树
细节要点
线段树: 建立一棵二叉树,用递归建图
- 在叶子节点处无需下放懒惰标记,所以懒惰标记可以不下传到叶子节点。
- 下放懒惰标记可以写一个专门的函数 pushdown,从儿子节点更新当前节点也可以写一个专门的函数 maintain(或者对称地用 pushup),降低代码编写难度。
- 标记永久化,如果确定懒惰标记不会在中途被加到溢出(即超过了该类型数据所能表示的最大范围),那么就可以将标记永久化。标记永久化可以避免下传懒惰标记,只需在进行询问时把标记的影响加到答案当中,从而降低程序常数。具体如何处理与题目特性相关,需结合题目来写。这也是树套树和可持久化数据结构中会用到的一种技巧。
模板
递归建树
// 递归建树
void build(int s, int t, int p) {
// 对 [s,t] 区间建立线段树,当前根的编号为 p
if(s == t) { // 递归到单个节点时记录
d[p] = a[s];
return;
}
int mid = s + ((t-s) >> 1); // 如果写成 (s + t) >> 1 可能会时间超限
build(s, mid, p*2), build(mid+1, t, p*2+1);
// 递归对左右区间建树
d[p] = d[p*2] + d[(p*2)+1];
}
区间求和
// 求区间和
int getsum(int l, int r, int s, int t, int p) {
// [l,r] 为查询区间,[s,t] 为当前节点管理的区间,p 为当前节点的编号
if(l <= s && t <= r) return d[p];// 当前区间为询问区间的子集时直接返回当前区间的和
int mid = s + ((t-s) >> 1), sum = 0;
if(l <= mid) sum += getsum(l, r, s, mid, p*2);
// 如果左儿子代表的区间 [l,m] 与询问区间有交集,则递归查询左儿子,一直到全部找到
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
// 如果右儿子代表的区间 [m+1,r] 与询问区间有交集,则递归查询右儿子
return sum;
}
区间修改(区间加上某个值)
// 区间修改(区间加上某个值)
void update(int l, int r, int c, int s, int t, int p) {
// [l,r] 为修改区间,c 为被修改的元素的变化量,[s,t] 为当前节点管理的区间,p为当前节点的编号
if(l <= s && t <= r) {
d[p] += (t-s+1) * c; // 区间内每个元素+c, p变化 len*c
lag[p] += c;
return ;
} // 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改
int mid = s + ((t-s) >> 1);
if(lag[p] && s!=t) { // 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p], lag[p*2+1] += lag[p]; // 将标记下传给子节点
lag[p] = 0; // 清空当前节点的标记
}
// 到达此处,说明修改区间小于管辖区间,需要向下细分区间去找
if (l <= mid) update(l, r, c, s, mid, p*2);
if (r > mid) update(l, r, c, mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
区间修改(区间求和)
// 区间求和
int getsum(int l, int r, int s, int t, int p) {
// [l,r] 为查询区间,[s,t] 为当前节点包含的区间,p为当前节点的编号
if(l <= s && t <= r) return d[p];// 当前区间为询问区间的子集时直接返回当前区间的和
int mid = s + ((t-s) >> 1);
if(lag[p]) {
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p], lag[p*2+1] += lag[p];
lag[p] = 0;
}
int sum = 0;
if(l <= mid) sum += getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid, t, p*2+1);
return sum;
}
区间修改(将区间修改为某一特定值)
// 将区间修改为某一特定值而不是加上
void pushdown() { // 懒人标记下放
d[p*2] += lag[p] * (mid-s+1), d[p*2+1] += lag[p] * (t-mid);
lag[p*2] += lag[p]; lag[p*2+1] += lag[p];
lag[p] = 0;
}
void update(int l, int r, int c, int s, int t, int p) {
if(l <= s && t <= r) {
d[p] = (t-s+1) * c;
lag[p] = c;
return ;
}
int mid = s + ((t-s) >> 1);
if(lag[p]) pushdown(); // 懒人标记下放
if(l <= mid) update(l, r, c, s, mid, p*2);
if(r > mid) update(l, r, c, mid, t, p*2+1);
}
int getsum(int l, int r, int s, int t, int p) {
if(l <= s && t <= r) return d[p];
int mid = s + ((t-s) >> 1);
if (lag[p]) pushdown();
int sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, p*2); // 每次sum都是从0开始的,所以 += 和 = 都可
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
return sum;
}
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int maxn = 1e5+10;
typedef unsigned long long ULL;
ll lag[maxn*4], d[maxn*4];
ll a[maxn];
void Build(ll s, ll t, ll p) {
if(s == t) {
d[p] = a[s]; // p是树的编号
return ;
}
ll mid = s + ((t-s)>>1);
Build(s, mid, p*2);
Build(mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
void pushdown(ll p, ll s, ll t) {
ll mid = s + ((t-s) >> 1);
d[p*2] += lag[p]*(mid-s+1); d[p*2+1] += lag[p]*(t-mid); // 注意要把区间算全
lag[p*2] += lag[p]; lag[p*2+1] += lag[p]; // 每次下lag要累计
lag[p] = 0;
}
void add(ll l, ll r, ll c, ll s, ll t, ll p) {
if(l <= s && t <= r) {
d[p] += c*(t-s+1);
lag[p] += c;
return;
}
ll mid = s + ((t-s) >> 1);
if(lag[p] && s!=t) pushdown(p, s, t);
if(l <= mid) add(l, r, c, s, mid, p*2);
if(r > mid) add(l, r, c, mid+1, t, p*2+1);
d[p] = d[p*2] + d[p*2+1];
}
ll getsum(ll l, ll r, ll s, ll t, ll p) {
if(l <= s && r >= t) return d[p];
int mid = s + ((t-s) >> 1);
if(lag[p]) pushdown(p, s, t);
ll sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, p*2);
if(r > mid) sum += getsum(l, r, mid+1, t, p*2+1);
return sum;
}
int main() {
// freopen("test.in", "r", stdin);
ll n, m;
scanf("%lld%lld", &n, &m);
for(int i = 1; i <= n; i++) scanf("%lld", &a[i]);
Build(1, n, 1);
while(m--) {
int op;
scanf("%d", &op);
if(op == 1) { // 区间加
ll l, r, k;
scanf("%lld%lld%lld", &l, &r, &k);
add(l, r, k, 1, n, 1);
}
else { // 区间求和
ll l, r;
scanf("%lld%lld", &l, &r);
printf("%lld\n", getsum(l, r, 1, n, 1));
}
}
return 0;
}
实现区间乘和区间加
debug 易错:
- 每次运算完记得 laz 下放
- 加法运算的时候 记得+laz*(区间长)
#include<cstdio>
#include<iostream>
#define ll long long
using namespace std;
const int maxn = 1e5+10;
inline ll read() {
ll x = 0, f = 1;
char c = getchar();
while(c < '0' || c > '9') {
if(c == '-') f = -1;
c = getchar();
}
while(c >= '0' && c <= '9') {
x = x*10+c-'0';
c = getchar();
}
return x*f;
}
ll n, m, mod;
ll a[maxn], sum[maxn*4], mul[maxn*4], laz[maxn*4];
void up(int i) {
sum[i] = (sum[i*2] + sum[i*2+1])%mod;
}
void pd(int i, int s, int t) {
int l = (i*2), r = (i*2+1), mid = s+t>>1; // mid是区间中点,不是 l , r
if(mul[i] != 1) { // 懒标记传递,两个懒标记 ,下放时,先下放乘,因为要作用于其整体,所以既要修改mul,也要修改laz
mul[l] *= mul[i]; mul[l] %= mod;
mul[r] *= mul[i]; mul[r] %= mod;
laz[l] *= mul[i]; laz[l] %= mod;
laz[r] *= mul[i]; laz[r] %= mod;
sum[l] *= mul[i]; sum[l] %= mod;
sum[r] *= mul[i]; sum[r] %= mod;
mul[i] = 1;
}
if(laz[i] != 0) {
sum[l] += laz[i]*(mid-s+1); sum[l] %= mod;
sum[r] += laz[i]*(t-mid); sum[r] %= mod;
laz[l] += laz[i]; laz[l] %= mod;
laz[r] += laz[i]; laz[r] %= mod;
laz[i] = 0;
}
return ;
}
void build(int s, int t, int i) {
mul[i] = 1;
laz[i] = 0;
if(s == t) {
sum[i] = a[s];
return ;
}
int mid = s+((t-s) >> 1);
build(s, mid, i*2);
build(mid+1, t, i*2+1);
up(i);
}
void multi(int l, int r, int s, int t, ll k, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && t <= r) {
sum[i] *= k; sum[i] %= mod;
mul[i] *= k; mul[i] %= mod;
laz[i] *= k; laz[i] %= mod;
return ;
}
pd(i, s, t);
if(l <= mid) multi(l, r, s, mid, k, i*2);
if(r > mid) multi(l, r, mid+1, t, k, i*2+1);
up(i);
}
void add(int l, int r, int s, int t, ll k, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && r >= t) {
sum[i] += k*(t-s+1); sum[i] %= mod; // 一定注意加的时候整个区间每个数都要加
laz[i] += k; laz[i] %= mod;
return ;
}
pd(i, s, t); // !!!!!!记得每一步预算都要下放laz,WA哭了
if(l <= mid) add(l, r, s, mid, k, i*2);
if(r > mid) add(l, r, mid+1, t, k, i*2+1);
up(i);
}
ll getsum(int l, int r, int s, int t, int i) {
int mid = s+((t-s) >> 1);
if(l <= s && t <= r) return sum[i] % mod;
pd(i, s, t); // 求和之前记得下放 laz
int sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, i*2);
if(mid < r) sum += getsum(l, r, mid+1, t, i*2+1);
return sum % mod;
}
int main() {
// freopen("test.in", "r", stdin);
n = read(); m = read(); mod = read();
// cout << "mod=" << mod << endl;
for(int i = 1; i <= n; i++) a[i] = read();
build(1, n, 1);
while(m--) {
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op == 1) { // multi
ll k = read();
multi(l, r, 1, n, k, 1);
}
else if(op == 2) { // add
ll k = read();
add(l, r, 1, n, k, 1);
}
else { // sum
printf("%lld\n", getsum(l, r, 1, n, 1));
}
}
return 0;
}
将区间修改为同一个值
关键代码部分
void up(i) {
sum[i] = sum[i*2] + sum[i*2+1];
}
void pd(int i, int s, int t) {
int l = i*2, r = i*2+1, mid = s+t>>1;
sum[l] += laz[i]*(mid-s+1);
sum[r] += laz[i]*(t-mid);
laz[i*2] += laz[i];
laz[i*2+1] += laz[i];
return ;
}
void build(int s, int t, int i) {
if(s == t) {
sum[i] = a[s];
return ;
}
int mid = l+r>>1;
build(s, mid, i*2);
build(mid+1, r, i*2+1);
up(i);
}
void update(int l, int r, int s, int t, int k, int i){
if(l <= s && r >= k) {
laz[i] = k; // 因为递归调用该函数,所以最终会把整个[l,r]修改为 k
sum[i] = k*(t-s+1);
return ;
}
pd(i, s, t);
int mid = s+t>>1;
if(l <= mid) update(l, r, s, mid, k, i*2);
if(r > mid) update(l, r, mid+1, t, k, i*2+1);
}
ll getsum(int l, int r, int s, int t, int i) {
if(l <= s && r >= t) return sum[i];
int mid = s+t>>1;
pd(i, s, t);
ll sum = 0;
if(l <= mid) sum = getsum(l, r, s, mid, i*2);
if(r > mid) sum += getsum(l, r, mid+1, t, i*2+1);
return sum;
}