线段树
简介
线段树是一棵二叉树，树中的每一个结点表示一个区间[a,b]，每一个叶子结点表示一个单位区间。对于每一个非叶结点所表示的区间[a,b]，其左儿子表示的区间为[a,(a+b)/2]，右儿子表示的区间为[(a+b)/2,b]。下文代码处理的是整数下标：记 mid = ⌊(a+b)/2⌋，则左儿子为[a,mid]，右儿子为[mid+1,b]，叶子结点就是 a = b 的单个位置。
线段树需要注意的几点：
1、线段树数组一般要开到数据规模的 4 倍（按 n<<1、n<<1|1 给结点编号时 4 倍一定够用；内存不紧张时可以适当多开）
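2、左右儿子的编号（lson、rson）尽量用位运算计算：左儿子 n<<1，右儿子 n<<1|1；想偷懒可以用 #define 定义成宏。另外，在 C++17 下配合 using namespace std 时，不要直接用 size 作为全局数组名并以 size[n] 访问——它会与 std::size 冲突导致编译错误，可以改名或写成 ::size。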
3、build 函数（下文会介绍）中，叶子结点处写 if (l == r) sum[n] = a[l]（此时 l == r，写 a[r] 也一样），注意不是 sum[n] = a[n]：n 是结点编号，l 才是原数组下标。
4、下文几个常用函数中，标记 ‘*’ 的行可根据需要修改（例如把求和改成求最大值）。
- 几个线段树的常用函数
// build: pseudocode for constructing the tree.
// Node n covers [l, r]; its children are n << 1 (covering [l, mid]) and
// n << 1 | 1 (covering [mid + 1, r]).
void build(int n, int l, int r){
    size[n] = r - l + 1;              // number of positions node n covers
    if (l == r){
        sum[n] = a[l]; return ;       // leaf: a[l] (== a[r]), never a[n] -- n is a node index
    }
    int mid = (l + r) >> 1;
    build(n << 1, l, mid);
    build(n << 1 | 1, mid + 1, r);
    sum[n] = sum[n << 1] + sum[n << 1 | 1];//* how two children combine (sum here)
}
// updata (range update): pseudocode.
// down: push node n's pending addition lazy[n] into both children, then clear it.
void down(int n){
    sum[n << 1] += lazy[n] * size[n << 1];
    sum[n << 1 | 1] += lazy[n] * size[n << 1 | 1];//*
    /* The two lines above are for range sums.
       For range maximum use instead (adding lazy[n] to every element raises
       the maximum by exactly lazy[n]):
       sum[n << 1] += lazy[n];
       sum[n << 1 | 1] += lazy[n]; */
    lazy[n << 1] += lazy[n], lazy[n << 1 | 1] += lazy[n], lazy[n] = 0;
}
updata (int n, int l, int r, int xl, int xr, int v){
if (l != r) down(n);/即在解决单点修改问题 --> 单点不需要down/
if (l == xl && r == xr){
sum[n] += v * size[n], lazy[n] += v;//*
return ;
}
int mid = (l + r) >> 1;
if (xr <= mid) updata(n << 1, l, mid, xl, xr, v);
else if (xl > mid) updata(n << 1 | 1, mid + 1, r, xl, xr, v);
else updata(n << 1, l, mid, xl, mid, v), updata(n << 1 | 1, mid + 1, r, mid + 1, xr, v);
sum[n] = sum[n << 1] + sum[n << 1 | 1];//*
}
// query: pseudocode for the range-sum query over [xl, xr]; node n covers [l, r].
int query (int n, int l, int r, int xl, int xr){
    if (l != r) down(n);                    // settle pending tags before reading
    if (xl <= l && r <= xr) return sum[n];  // node lies completely inside the query
    int mid = (l + r) >> 1;
    if (xr <= mid) return query(n << 1, l, mid, xl, xr);          // left half only
    if (xl > mid) return query(n << 1 | 1, mid + 1, r, xl, xr);   // right half only
    int lhs_sum = query(n << 1, l, mid, xl, mid);                  // straddles mid: split
    int rhs_sum = query(n << 1 | 1, mid + 1, r, mid + 1, xr);
    return lhs_sum + rhs_sum;
}
- 下面看几道例题(题目差不多)
Codevs 1080线段树练习
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define ls (n << 1)
#define rs (n << 1 | 1)
using namespace std;
// Global state (zero-initialised).
// Tree arrays are indexed by node number; 1000100 slots hold 4 * N nodes for N <= 250025.
int sum[1000100];   // per-node segment sum
int lazy[1000100];  // per-node pending addition
int size[1000100];  // per-node number of covered positions
int a[1000100];     // input sequence, 1-based
int num;            // sequence length
int m;              // number of operations
int bj;             // current operation type
int n, v;           // point-update position and amount
int xl, xr;         // range-sum query bounds
void build(int n, int l, int r){
size[n] = r - l + 1;
if (l == r){
sum[n] = a[l];
return ;
}
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
sum[n] = sum[ls] + sum[rs];
}
void down(int n){
sum[ls] += lazy[n] * size[ls];
sum[rs] += lazy[n] * size[rs];
lazy[ls] += lazy[n], lazy[rs] += lazy[n], lazy[n] = 0;
}
void add(int n, int l, int r, int xl, int xr, int v){
if (l != r) down(n);
if (xl == l && r == xr){
sum[n] += v * size[n];
lazy[n] += v;
return ;
}
int mid = (l + r) >> 1;
if (xr <= mid) add(ls, l, mid, xl, xr, v);
else if (xl > mid) add(rs, mid + 1, r, xl, xr, v);
else add(ls, l, mid, xl, mid, v), add(rs, mid + 1, r, mid + 1, xr, v);
sum[n] = sum[ls] + sum[rs];
}
// Sum of a[xl..xr]; node n covers [l, r] and the caller keeps [xl, xr] inside it.
int query(int n, int l, int r, int xl, int xr){
    if (l != r) down(n);                     // flush pending additions first
    if (xl <= l && r <= xr) return sum[n];   // node fully inside the query
    const int mid = (l + r) >> 1;
    const bool only_left = xr <= mid;
    const bool only_right = !only_left && xl > mid;
    if (only_left) return query(ls, l, mid, xl, xr);
    if (only_right) return query(rs, mid + 1, r, xl, xr);
    const int lhs_sum = query(ls, l, mid, xl, mid);
    const int rhs_sum = query(rs, mid + 1, r, mid + 1, xr);
    return lhs_sum + rhs_sum;
}
int main()
{
    // Input: sequence length, the sequence (1-based), then m operations.
    scanf("%d", &num);
    for (int pos = 1; pos <= num; ++pos) scanf("%d", &a[pos]);
    build(1, 1, num);
    scanf("%d", &m);
    for (int op = 0; op < m; ++op){
        scanf("%d", &bj);
        if (bj == 1){
            // type 1: "n v" -- add v to a[n]
            scanf("%d %d", &n, &v);
            add(1, 1, num, n, n, v);
        } else {
            // any other type: "xl xr" -- print the sum of a[xl..xr]
            scanf("%d %d", &xl, &xr);
            printf("%d\n", query(1, 1, num, xl, xr));
        }
    }
    return 0;
}
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define ls (n << 1)
#define rs (n << 1 | 1)
#define L 500000
using namespace std;
// Global state (zero-initialised); L is the capacity macro defined above.
int sum[L];   // per-node segment sum
int lazy[L];  // per-node pending addition
int size[L];  // per-node number of covered positions
int a[L];     // input sequence, 1-based
int num, m;   // sequence length, number of operations
int bj;       // current operation type
int l, r, v;  // range-add bounds and amount
int j;        // point-query position
int n;        // not read by main(); the functions take their own `n`
// Fast integer reader for stdin: skips everything up to the next digit and
// parses the digit run. A '-' immediately before the digits makes the result
// negative. At end of input it returns 0. The original looped forever there,
// because a `char` holding EOF never looks like a digit.
int gi(){
    int ch = getchar();  // int, so EOF stays distinguishable from real bytes
    bool negative = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative = (ch == '-');  // only a '-' directly before the digits counts
        ch = getchar();
    }
    int ans = 0;
    while (ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -ans : ans;
}
void build(int n, int l, int r){
size[n] = r - l + 1;
if (l == r){
sum[n] = a[l];
return ;
}
int mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
sum[n] = sum[ls] + sum[rs];
}
void down(int n){
sum[ls] += lazy[n] * size[ls];
sum[rs] += lazy[n] * size[rs];
lazy[ls] += lazy[n], lazy[rs] += lazy[n], lazy[n] = 0;
}
void add(int n, int l, int r, int xl, int xr, int v){
if (l != r) down(n);
if (xl == l && r == xr){
sum[n] += v * size[n];
lazy[n] += v;
return ;
}
int mid = (l + r) >> 1;
if (xr <= mid) add(ls, l, mid, xl, xr, v);
else if (xl > mid) add(rs, mid + 1, r, xl, xr, v);
else add(ls, l, mid, xl, mid, v), add(rs, mid + 1, r, mid + 1, xr, v);
sum[n] = sum[ls] + sum[rs];
}
// Sum of a[xl..xr]; node n covers [l, r] and the caller keeps [xl, xr] inside it.
// main() uses it with xl == xr to read a single element.
int query(int n, int l, int r, int xl, int xr){
    if (l != r) down(n);                     // flush pending additions first
    if (xl <= l && r <= xr) return sum[n];   // node fully inside the query
    const int mid = (l + r) >> 1;
    const bool only_left = xr <= mid;
    const bool only_right = !only_left && xl > mid;
    if (only_left) return query(ls, l, mid, xl, xr);
    if (only_right) return query(rs, mid + 1, r, xl, xr);
    const int lhs_sum = query(ls, l, mid, xl, mid);
    const int rhs_sum = query(rs, mid + 1, r, mid + 1, xr);
    return lhs_sum + rhs_sum;
}
int main()
{
    // Header and initial sequence go through gi(); the operations use scanf.
    num = gi();
    for (int pos = 1; pos <= num; ++pos) a[pos] = gi();
    build(1, 1, num);
    m = gi();
    for (int op = 0; op < m; ++op){
        scanf("%d", &bj);
        if (bj != 1){
            // any other type: "j" -- print the current value of a[j]
            scanf("%d", &j);
            printf("%d\n", query(1, 1, num, j, j));
            continue;
        }
        // type 1: "l r v" -- add v to a[l..r]
        scanf("%d %d %d", &l, &r, &v);
        add(1, 1, num, l, r, v);
    }
    return 0;
}
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define ls (n << 1)
#define rs (n << 1 | 1)
#define L 1000000
using namespace std;
// Global state (zero-initialised); L is the capacity macro defined above.
long long sum[L];   // per-node segment sum
long long lazy[L];  // per-node pending addition
long long size[L];  // per-node number of covered positions
long long a[L];     // input sequence, 1-based
long long num;      // sequence length
long long m;        // number of operations
long long bj;       // current operation type
long long l, r, v;  // range-add bounds and amount
long long j, k;     // range-sum query bounds
long long n;        // not read by main(); the functions take their own `n`
// Fast integer reader for stdin: skips everything up to the next digit and
// parses the digit run. A '-' immediately before the digits makes the result
// negative. At end of input it returns 0. The original looped forever there,
// because a `char` holding EOF never looks like a digit.
long long gi(){
    int ch = getchar();  // int, so EOF stays distinguishable from real bytes
    bool negative = false;
    while (ch != EOF && (ch < '0' || ch > '9')) {
        negative = (ch == '-');  // only a '-' directly before the digits counts
        ch = getchar();
    }
    long long ans = 0;
    while (ch >= '0' && ch <= '9') {
        ans = ans * 10 + (ch - '0');
        ch = getchar();
    }
    return negative ? -ans : ans;
}
void build(long long n, long long l, long long r){
size[n] = r - l + 1;
if (l == r){
sum[n] = a[l];
return ;
}
long long mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
sum[n] = sum[ls] + sum[rs];
}
void down(long long n){
sum[ls] += lazy[n] * size[ls];
sum[rs] += lazy[n] * size[rs];
lazy[ls] += lazy[n], lazy[rs] += lazy[n], lazy[n] = 0;
}
void add(long long n, long long l, long long r, long long xl, long long xr, long long v){
if (l != r) down(n);
if (xl == l && r == xr){
sum[n] += v * size[n];
lazy[n] += v;
return ;
}
long long mid = (l + r) >> 1;
if (xr <= mid) add(ls, l, mid, xl, xr, v);
else if (xl > mid) add(rs, mid + 1, r, xl, xr, v);
else add(ls, l, mid, xl, mid, v), add(rs, mid + 1, r, mid + 1, xr, v);
sum[n] = sum[ls] + sum[rs];
}
// Sum of a[xl..xr]; node n covers [l, r] and the caller keeps [xl, xr] inside it.
long long query(long long n, long long l, long long r, long long xl, long long xr){
    if (l != r) down(n);                     // flush pending additions first
    if (xl <= l && r <= xr) return sum[n];   // node fully inside the query
    const long long mid = (l + r) >> 1;
    const bool only_left = xr <= mid;
    const bool only_right = !only_left && xl > mid;
    if (only_left) return query(ls, l, mid, xl, xr);
    if (only_right) return query(rs, mid + 1, r, xl, xr);
    const long long lhs_sum = query(ls, l, mid, xl, mid);
    const long long rhs_sum = query(rs, mid + 1, r, mid + 1, xr);
    return lhs_sum + rhs_sum;
}
int main()
{
    // Header and initial sequence go through gi(); the operations use scanf.
    num = gi();
    for (int pos = 1; pos <= num; ++pos) a[pos] = gi();
    build(1, 1, num);
    m = gi();
    for (int op = 0; op < m; ++op){
        scanf("%lld", &bj);
        if (bj != 1){
            // any other type: "j k" -- print the sum of a[j..k]
            scanf("%lld %lld", &j, &k);
            printf("%lld\n", query(1, 1, num, j, k));
            continue;
        }
        // type 1: "l r v" -- add v to a[l..r]
        scanf("%lld %lld %lld", &l, &r, &v);
        add(1, 1, num, l, r, v);
    }
    return 0;
}
线段树求和板子题（区间加、区间求和，下标从 0 开始）
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define ls (n << 1)
#define rs (n << 1 | 1)
#define L 5000000
using namespace std;
// Global state (zero-initialised); L is the capacity macro defined above.
long long sum[L];   // per-node segment sum
long long lazy[L];  // per-node pending addition
long long size[L];  // per-node number of covered positions
long long a[L];     // input sequence, 0-based
long long num, m;   // sequence length, number of operations
long long l, r, v;  // range-add bounds and amount
long long j, k;     // range-sum query bounds
long long n;        // not read by main(); the functions take their own `n`
char bj;            // operation letter: 'C' (add) or 'Q' (query)
void build(long long n, long long l, long long r){
size[n] = r - l + 1;
if (l == r){
sum[n] = a[l];
return ;
}
long long mid = (l + r) >> 1;
build(ls, l, mid), build(rs, mid + 1, r);
sum[n] = sum[ls] + sum[rs];
}
void down(long long n){
sum[ls] += lazy[n] * size[ls];
sum[rs] += lazy[n] * size[rs];
lazy[ls] += lazy[n], lazy[rs] += lazy[n], lazy[n] = 0;
}
void add(long long n, long long l, long long r, long long xl, long long xr, long long v){
if (l != r) down(n);
if (xl == l && r == xr){
sum[n] += v * size[n];
lazy[n] += v;
return ;
}
long long mid = (l + r) >> 1;
if (xr <= mid) add(ls, l, mid, xl, xr, v);
else if (xl > mid) add(rs, mid + 1, r, xl, xr, v);
else add(ls, l, mid, xl, mid, v), add(rs, mid + 1, r, mid + 1, xr, v);
sum[n] = sum[ls] + sum[rs];
}
// Sum of a[xl..xr]; node n covers [l, r] and the caller keeps [xl, xr] inside it.
long long query(long long n, long long l, long long r, long long xl, long long xr){
    if (l != r) down(n);                     // flush pending additions first
    if (xl <= l && r <= xr) return sum[n];   // node fully inside the query
    const long long mid = (l + r) >> 1;
    const bool only_left = xr <= mid;
    const bool only_right = !only_left && xl > mid;
    if (only_left) return query(ls, l, mid, xl, xr);
    if (only_right) return query(rs, mid + 1, r, xl, xr);
    const long long lhs_sum = query(ls, l, mid, xl, mid);
    const long long rhs_sum = query(rs, mid + 1, r, mid + 1, xr);
    return lhs_sum + rhs_sum;
}
int main()
{
    // num, m and a[] are long long, so they must be read with %lld. The original
    // %d wrote only an int's worth of bytes (undefined behavior); negative
    // inputs came out as huge positive values.
    scanf("%lld %lld", &num, &m);
    for (int i = 0; i < num; ++i) scanf("%lld", &a[i]);
    build(1, 0, num - 1);  // positions are 0-based: [0, num - 1]
    for (int i = 1; i <= m; ++i){
        cin >> bj;
        // 'C l r v': add v to a[l..r];  'Q j k': print the sum of a[j..k]
        if (bj == 'C') scanf("%lld %lld %lld", &l, &r, &v), add(1, 0, num - 1, l, r, v);
        if (bj == 'Q') scanf("%lld %lld", &j, &k), printf("%lld\n", query(1, 0, num - 1, j, k));
    }
    return 0;
}
线段树求最大值的板子题（在末尾追加一个数，查询末尾若干个数中的最大值）
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define LL long long
#define ls (n << 1)
#define rs (n << 1 | 1)
#define L 1000000
using namespace std;
// Global state (zero-initialised); L is the capacity macro defined above.
LL num[L];  // max-tree node values, indexed by node number
LL m, d;    // number of operations, modulus
LL l;       // how many numbers have been appended so far
LL t;       // answer of the most recent query (0 before any query)
LL x;       // argument of the current operation
char bj;    // operation letter: 'A' (append) or 'Q' (query)
// Initialise the max-tree over [l, r]: every leaf is 0 and every internal
// node takes the larger of its children, so all nodes end up 0.
void buildtree(int n, int l, int r) {
    if (l != r) {
        const int mid = (l + r) >> 1;
        buildtree(n << 1, l, mid);
        buildtree(n << 1 | 1, mid + 1, r);
        num[n] = max(num[n << 1], num[n << 1 | 1]);
        return;
    }
    num[n] = 0;
}
// Maximum over positions [a, b]; node n covers [l, r] and the caller keeps
// l <= a <= b <= r. A range straddling mid is split so each child receives a
// sub-range lying entirely inside its own interval.
LL query(int n, int l, int r, int a, int b) {
    if (l == a && b == r) return num[n];
    const int mid = (l + r) >> 1;
    if (b > mid && a <= mid) {
        const LL from_left = query(n << 1, l, mid, a, mid);
        const LL from_right = query(n << 1 | 1, mid + 1, r, mid + 1, b);
        return max(from_left, from_right);
    }
    return (b <= mid) ? query(n << 1, l, mid, a, b)
                      : query(n << 1 | 1, mid + 1, r, a, b);
}
// Point assignment (despite the name): set position a to v, then refresh the
// maxima along the path back up to node n.
void add(int n, int l, int r, int a, LL v) {
    if (l == a && r == a) {
        num[n] = v;
        return;
    }
    const int mid = (l + r) >> 1;
    if (a > mid) add(n << 1 | 1, mid + 1, r, a, v);
    else add(n << 1, l, mid, a, v);
    num[n] = max(num[n << 1], num[n << 1 | 1]);
}
int main() {
    scanf("%lld %lld", &m, &d);
    // All values start at 0 (globals are zero-initialised). Building over
    // [1, m] keeps the root range identical to the one add/query use below.
    // The guard avoids recursing on an empty range when there are no operations.
    if (m > 0) buildtree(1, 1, m);
    for (int i = 1; i <= m; ++i) {
        cin >> bj >> x;
        if (bj == 'Q') {
            // maximum of the last x appended numbers
            t = query(1, 1, m, l - x + 1, l);
            printf("%lld\n", t);  // one answer per line (was missing the newline)
        }
        // append (x + last answer) mod d at the next position
        if (bj == 'A') add(1, 1, m, ++l, (t + x) % d);
    }
    return 0;
}
线段树区间修改板子题（区间覆盖 / 撤销覆盖，每次操作后输出被覆盖的连续段数和覆盖总长度）
#include <iostream>
#include <cstdio>
#include <cstdlib>
#define L 1000000
#define ls (n << 1)
#define rs (n << 1 | 1)
using namespace std;
// Global state (zero-initialised); L is the capacity macro defined above.
int l;        // length of the whole range [1, l]
int n;        // number of operations
int m;        // type of the current operation (1 = cover, otherwise uncover)
int a, t;     // the current operation targets [a, a + t - 1]
int cnt[L];   // per node: number of maximal covered runs inside its interval
int len[L];   // per node: total covered length inside its interval
int lc[L];    // per node: 1 if its leftmost position is covered
int rc[L];    // per node: 1 if its rightmost position is covered
int lazy[L];  // per node: net cover counter of operations spanning the whole node
int x[L];     // not referenced anywhere in this program
// Pull-up (despite the name): recompute node n's summary. If its own cover
// counter is positive, the whole [l, r] is one covered run. Otherwise it is
// merged from the children, joining the two runs that touch at the midpoint.
void down(int n, int l, int r) {
    if (lazy[n] <= 0) {
        const int left = n << 1;
        const int right = n << 1 | 1;
        lc[n] = lc[left];
        rc[n] = rc[right];
        len[n] = len[left] + len[right];
        const bool joined = (rc[left] == 1 && lc[right] == 1);
        cnt[n] = cnt[left] + cnt[right] - (joined ? 1 : 0);
        return;
    }
    lc[n] = 1;
    rc[n] = 1;
    cnt[n] = 1;
    len[n] = r - l + 1;
}
// Cover-tree update without lazy push-down: apply v to positions [a, b], where
// main() passes +1 for a cover and -1 for an uncover. Node n covers [l, r].
void add(int n, int l, int r, int a, int b, int v) {
// Node completely inside [a, b]: record the change on the node itself only.
if (a <= l && r <= b) {
lazy[n] += v;
if (l == r) {
// Leaf: its summary is just "covered or not", and the counter is clamped to
// 1 or 0 here. NOTE(review): internal nodes keep the exact net count, so a
// leaf covered twice then uncovered once ends up uncovered while an internal
// node would stay covered -- confirm the input never overlaps covers like that.
if (lazy[n] > 0)
lazy[n] = rc[n] = lc[n] = cnt[n] = len[n] = 1;
else lazy[n] = rc[n] = lc[n] = cnt[n] = len[n] = 0;
}
// Internal node: recompute its summary. down() treats lazy[n] > 0 as
// "whole interval covered" and otherwise merges the two children.
else down(n, l, r);
}
else {
// Partial overlap: recurse into the children that intersect [a, b]
// (right first, then left; the two calls touch disjoint subtrees), then
// rebuild this node's summary from them.
int mid = (l + r) >> 1;
if (b > mid) add(rs, mid + 1, r, a, b, v);
if (a <= mid) add(ls, l, mid, a, b, v);
down(n, l, r);
}
}
int main() {
    scanf("%d %d", &l, &n);
    for (int op = 0; op < n; ++op) {
        scanf("%d %d %d", &m, &a, &t);
        // type 1 covers [a, a + t - 1]; any other type removes one cover from it
        const int delta = (m == 1) ? 1 : -1;
        add(1, 1, l, a, a + t - 1, delta);
        // after every operation: number of covered runs, total covered length
        printf("%d %d\n", cnt[1], len[1]);
    }
    return 0;
}
——CYCKN