一、概述
线段树的适用范围很广,可以在线维护修改以及查询区间上的最值,求和。对于线段树来说,每次更新以及查询的时间复杂度为O(logn)。
二、简单线段树(无pushdown)
1.单点修改,区间查询
#include<bits/stdc++.h>
using namespace std;
#define qio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
typedef long long ll;
const int N = 5e5 + 10;
int a[N];
struct node {
int l, r;
ll sum;
} tr[4 * N];
void build(int i, int l, int r) {
tr[i].l = l, tr[i].r = r;
if (l == r) {
tr[i].sum = a[l];
return;
}
int mid = (l + r) >> 1;
build(i * 2, l, mid);
build(i * 2 + 1, mid + 1, r);
tr[i].sum = tr[i * 2].sum + tr[i * 2 + 1].sum;
}
void update(int i, int pos, int value) {
if (tr[i].l == tr[i].r) {
tr[i].sum += value;
return;
}
if (pos <= tr[i * 2].r) update(i * 2, pos, value);
else update(i * 2 + 1, pos, value);
tr[i].sum = tr[i * 2].sum + tr[i * 2 + 1].sum;
}
ll search(int i, int l, int r) {
if (l <= tr[i].l && tr[i].r <= r) return tr[i].sum;
ll res = 0;
if (tr[i * 2].r >= l) res += search(i * 2, l, r);
if (tr[i * 2 + 1].l <= r) res += search(i * 2 + 1, l, r);
return res;
}
int main() {
int n, m; cin >> n >> m;
for (int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
for (int i = 1; i <= m; i++) {
int op; cin >> op;
if (op == 1) {
int pos, value; cin >> pos >> value;
update(1, pos, value);
} else {
int l, r; cin >> l >> r;
cout << search(1, l, r) << '\n';
}
}
}
2.区间修改,单点查询
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 5e5 + 10;
int a[N];
struct node{
int l, r;
int num;
}tr[4 * N];
void build(int i, int l, int r){
tr[i] = {l, r, 0};
if(l == r){
tr[i].num = a[l];
return;
}
int mid = (l + r) >> 1;
build(i << 1 , l, mid);
build(i << 1 | 1, mid + 1, r);
}
void update(int i, int l, int r, int x){
if(l <= tr[i].l && tr[i].r <= r){
tr[i].num += x;
return;
}
int mid = (tr[i].l + tr[i].r) >> 1;
if(l <= mid)update(i << 1, l, r, x);
if(mid < r)update(i << 1 | 1, l, r, x);
}
ll query(int i, int pos){
if(tr[i].l == tr[i].r) return tr[i].num;
ll res = 0;
int mid = (tr[i].l + tr[i].r) >> 1;
if(pos <= mid) res += query(i << 1, pos);
else res += query(i << 1 | 1, pos);
return res + tr[i].num;
}
int main() {
int n, m; cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--){
int op; cin >> op;
if(op == 1){
int l, r, k; cin >> l >> r >> k;
update(1, l, r, k);
} else {
int x; cin >> x;
cout << query(1, x) << '\n';
}
}
}
三、进阶线段树
1.区间修改,区间查询
P3372 【模板】线段树 1 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5 + 10;
int a[N];
struct node{
int l, r;
ll num;
int lz;
}tr[4 * N];
void build(int i, int l, int r){
tr[i] = {l, r, 0, 0};
if(l == r){
tr[i].num = a[l];
return;
}
int mid = (l + r) >> 1;
build(i << 1 , l, mid);
build(i << 1 | 1, mid + 1, r);
tr[i].num = tr[i << 1].num + tr[i << 1 | 1].num;
}
void pushdown(int i){
if(tr[i].lz){
tr[i << 1].lz += tr[i].lz;
tr[i << 1 | 1].lz += tr[i].lz;
int mid = (tr[i].l + tr[i].r) >> 1;
tr[i << 1].num += tr[i].lz * (mid - tr[i << 1].l + 1);
tr[i << 1 | 1].num += tr[i].lz * (tr[i << 1 | 1].r - mid);
tr[i].lz = 0;
}
}
void update(int i, int l, int r, int x){
if(l <= tr[i].l && tr[i].r <= r){
tr[i].num += x * (tr[i].r - tr[i].l + 1);
tr[i].lz += x;
return;
}
pushdown(i);
int mid = (tr[i].l + tr[i].r) >> 1;
if(l <= mid)update(i << 1, l, r, x);
if(mid < r)update(i << 1 | 1, l, r, x);
tr[i].num = tr[i << 1].num + tr[i << 1 | 1].num;
}
ll query(int i, int l, int r){
if(l <= tr[i].l && tr[i].r <= r) return tr[i].num;
pushdown(i);
ll res = 0;
if(l <= tr[i << 1].r) res += query(i << 1, l, r);
if(tr[i << 1 | 1].l <= r) res += query(i << 1 | 1, l, r);
return res;
}
int main() {
int n, m; cin >> n >> m;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--){
int op; cin >> op;
if(op == 1){
int l, r, k; cin >> l >> r >> k;
update(1, l, r, k);
} else {
int l, r; cin >> l >> r;
cout << query(1, l, r) << '\n';
}
}
}
2.乘法线段树
P3373 【模板】线段树 2 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
先乘后加
#include<bits/stdc++.h>
using namespace std;
#define qio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
typedef long long ll;
const int N = 1e5 + 10;
int a[N];
ll mod;
struct node{
int l, r;
ll num;
ll plz, mlz;
}tr[4 * N];
void build(int i, int l, int r){
tr[i] = {l, r, 0, 0, 1};
if(l == r){
tr[i].num = a[l] % mod;
return;
}
int mid = (l + r) >> 1;
build(i << 1 , l, mid);
build(i << 1 | 1, mid + 1, r);
tr[i].num = (tr[i << 1].num + tr[i << 1 | 1].num) % mod;
}
void pushdown(int i){
//下级示数先乘i的mlz标记,再加上i的plz标记
ll mid = (tr[i].l + tr[i].r) >> 1;
tr[i << 1].num = (tr[i].mlz * tr[i << 1].num + ((mid - tr[i].l + 1) * tr[i].plz) % mod) % mod;
tr[i << 1 | 1].num = (tr[i].mlz * tr[i << 1 | 1].num + ((tr[i].r - mid) * tr[i].plz) % mod) % mod;
//下级mlz标记乘以i的mlz标记
tr[i << 1].mlz = (tr[i << 1].mlz * tr[i].mlz) % mod;
tr[i << 1 | 1].mlz = (tr[i << 1 | 1].mlz * tr[i].mlz) % mod;
//下级plz标记先乘i的mlz标记,再加上i的plz标记
tr[i << 1].plz = (tr[i << 1].plz * tr[i].mlz + tr[i].plz) % mod;
tr[i << 1 | 1].plz = (tr[i << 1 | 1].plz * tr[i].mlz + tr[i].plz) % mod;
//i的mlz,plz标记重置
tr[i].mlz = 1, tr[i].plz = 0;
}
void add(int i, int l, int r, int x){
if(l <= tr[i].l && tr[i].r <= r){
tr[i].num = (tr[i].num + x * (tr[i].r - tr[i].l + 1)) % mod;
tr[i].plz = (tr[i].plz + x) % mod;
return;
}
pushdown(i);
int mid = (tr[i].l + tr[i].r) >> 1;
if(l <= mid) add(i << 1, l, r, x);
if(mid < r) add(i << 1 | 1, l, r, x);
tr[i].num = (tr[i << 1].num + tr[i << 1 | 1].num) % mod;
}
void mul(int i, int l, int r, int x){
if(l <= tr[i].l && tr[i].r <= r){
tr[i].num = (tr[i].num * x) % mod;
tr[i].mlz = (tr[i].mlz * x) % mod;
tr[i].plz = (tr[i].plz * x) % mod;
return;
}
pushdown(i);
int mid = (tr[i].l + tr[i].r) >> 1;
if(l <= mid) mul(i << 1, l, r, x);
if(mid < r) mul(i << 1 | 1, l, r, x);
tr[i].num = (tr[i << 1].num + tr[i << 1 | 1].num) % mod;
}
ll query(int i, int l, int r){
if(l <= tr[i].l && tr[i].r <= r) return tr[i].num;
pushdown(i);
ll res = 0;
if(l <= tr[i << 1].r) res = (res + query(i << 1, l, r)) % mod;
if(tr[i << 1 | 1].l <= r) res = (res + query(i << 1 | 1, l, r)) % mod;
return res;
}
int main() {
qio
int n, m; cin >> n >> m;
cin >> mod;
for(int i = 1; i <= n; i++) cin >> a[i];
build(1, 1, n);
while(m--){
int op; cin >> op;
if(op == 1){
int l, r, k; cin >> l >> r >> k;
mul(1, l, r, k);
}
if(op == 2){
int l, r, k; cin >> l >> r >> k;
add(1, l, r, k);
}
if(op == 3) {
int l, r; cin >> l >> r;
cout << query(1, l, r) << '\n';
}
}
}
3.线段树与树链剖分
线段树需要在一个序列上进行,最好是一个数组,那如果给定了一棵树该怎么办?
用dfs遍历这棵树,并把遍历的顺序存到dfn中记录搜索的顺序,并记录in数组和out数组表示开始搜索子树i的左边界和右边界,即之后update或者query的l和r。在dfn序列上build线段树即可。
#include<bits/stdc++.h>
using namespace std;
#define qio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
typedef long long ll;
typedef double db;
const int N = 2e5 + 10;
vector<int> g[N];
int a[N];
struct tree{
int l, r;
ll num;
ll tot;
int lz;
}tr[N * 4];
//dfn[]表示dfs的顺序,之后将线段树加在dfn[]上
//in[]表示u节点子树在序列上的左边界,out[]表示u节点子树在序列上的有边界
int in[N], out[N], dfn[N], cnt;
void dfs(int now, int fa){
//开始遍历now节点子树
in[now] = ++cnt, dfn[cnt] = now;
for(auto &nex : g[now]){
if(nex == fa) continue;
dfs(nex, now);
}
//结束遍历now节点子树
out[now] = cnt;
}
void build(int i, int l, int r){
tr[i] = {l, r, 0, 0, 0};
if(l == r){
tr[i].num = a[dfn[l]];
tr[i].tot = 1;
return;
}
int mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
tr[i].num = tr[i << 1].num + tr[i << 1 | 1].num;
tr[i].tot = tr[i << 1].tot + tr[i << 1 | 1].tot;
}
void pushdown(int i){
if(tr[i].lz){
//将灯的数量转换
tr[i << 1].num = tr[i << 1].tot - tr[i << 1].num;
tr[i << 1 | 1].num = tr[i << 1 | 1].tot - tr[i << 1 | 1].num;
//lz标记修改
tr[i << 1].lz ^= 1;
tr[i << 1 | 1].lz ^= 1;
tr[i].lz = 0;
}
}
void update(int i, int l, int r){
if(l <= tr[i].l && tr[i].r <= r){
//将灯的数量转换
tr[i].num = tr[i].tot - tr[i].num;
//lz标记修改
//0变1:之前没修改过,现在修改了
//1变0:之前修改过一次了,再修改一次相当于没修改
tr[i].lz ^= 1;
return;
}
pushdown(i);
int mid = (tr[i].l + tr[i].r) >> 1;
if(l <= mid) update(i << 1, l, r);
if(mid < r) update(i << 1 | 1, l, r);
tr[i].num = tr[i << 1].num + tr[i << 1 | 1].num;
tr[i].tot = tr[i << 1].tot + tr[i << 1 | 1].tot;
}
ll query(int i, int l, int r){
if(l <= tr[i].l && tr[i].r <= r) return tr[i].num;
pushdown(i);
ll res = 0;
if(l <= tr[i << 1].r) res += query(i << 1, l, r);
if(tr[i << 1 | 1].l <= r) res += query(i << 1 | 1, l, r);
return res;
}
signed main() {
int n; cin >> n;
for(int i = 2; i <= n; i++){
int x; cin >> x;
g[x].push_back(i);
}
for(int i = 1; i <= n; i++) cin >> a[i];
dfs(1, -1);
// for(int i = 1; i <= cnt; i++){
// cout << dfn[i] << ' ';
// }
// cout << '\n';
// for(int i = 1; i <= n; i++){
// cout << in[i] << ' ' << out[i] << '\n';
// }
build(1, 1, n);
int q; cin >> q;
while(q--){
string s; cin >> s;
int x; cin >> x;
if(s == "pow"){
update(1, in[x], out[x]);
}else{
cout << query(1, in[x], out[x]) << '\n';
}
}
}
四、线段树引申
1.扫描线&&矩形面积井
P5490 【模板】扫描线&&矩形面积并 - 洛谷 | 计算机科学教育新生态 (luogu.com.cn)
分析详见题解,太难了阿巴阿巴
【学习笔记】扫描线 - 洛谷专栏 (luogu.com.cn)
#include<bits/stdc++.h>
using namespace std;
#define qio ios::sync_with_stdio(0), cin.tie(0), cout.tie(0);
typedef long long ll;
#define int ll
const int N = 1e6 + 10;
struct Scanline{
ll l, r, y;
//mark标记该线段是入边1还是出边-1
int mark;
bool operator < (const Scanline &t) const{
return y < t.y;
}
}line[2 * N];
struct Segtree{
int l, r;
//num记录该节点,即该划分线段是否还有贡献,num = 0时无贡献
int num;
//这个线段区间覆盖的长度
ll len;
}tr[4 * N];
int X[2 * N];
void build(int i, int l, int r){
tr[i] = {l, r, 0, 0};
if(l == r) return;
int mid = (l + r) >> 1;
build(i << 1, l, mid);
build(i << 1 | 1, mid + 1, r);
}
void pushup(int i){
//如果有贡献
if(tr[i].num){
//tr[i].r + 1处理,因为线段树节点为线段,l,r表示的是线段的第1234条
//所以对于一个叶子节点来说,tr[i].r = tr[i].l,则X也相等
//而X表示的是坐标,叶子节点线段的真实长度应该为X[tr[i].r + 1] - X[tr[i].l]
tr[i].len = X[tr[i].r + 1] - X[tr[i].l];
}else{
//无贡献则上传左右儿子的结果
tr[i].len = tr[i << 1].len + tr[i << 1 | 1].len;
}
}
void update(int i, int l, int r, int mark){
if(X[tr[i].r + 1] <= l || r <= X[tr[i].l]) return;
if(l <= X[tr[i].l] && X[tr[i].r + 1] <= r){
tr[i].num += mark;
pushup(i);
return;
}
update(i << 1, l, r, mark);
update(i << 1 | 1, l, r, mark);
pushup(i);
}
void solve() {
int n; cin >> n;
for(int i = 1; i <= n; i++){
int x1, y1, x2, y2; cin >> x1 >> y1 >> x2 >> y2;
//离散化,记录扫描线
X[(i << 1) - 1] = x1, X[i << 1] = x2;
line[(i << 1) - 1] = {x1, x2, y1, 1};
line[i << 1] = {x1, x2, y2, -1};
}
n <<= 1;
//扫描线按y从小到大排序
sort(line + 1, line + n + 1);
//X坐标离散化从小到大排序,cnt记录去重后X的个数
sort(X + 1, X + 1 + n);
int cnt = unique(X + 1, X + 1 + n) - X - 1;
//线段的数量一共有cnt - 1条
build(1, 1, cnt - 1);
ll ans = 0;
for(int i = 1; i <= n - 1; i++){
update(1, line[i].l, line[i].r, line[i].mark);
ans += tr[1].len * (line[i + 1].y - line[i].y);
}
cout << ans << '\n';
}
signed main() {
qio
int T = 1;
// cin >> T;
while (T--)solve();
}