想做这个事情已经比较久啦,代码能力非常需要保持,把各个板子都重新打一遍,汇总一下,方便之后做比赛自己的常用板子。
点修改,区间求和
#include <bits/stdc++.h>
const int maxn = 1e5+100;
using namespace std;
int a[maxn];
struct SegmentTree{
int l,r;
int sum;
}t[maxn*4];
void build(int o,int l,int r){
t[o].l=l,t[o].r=r;
if (l == r) {t[o].sum = a[l];return;}
int mid = (l+r)>>1;
build(o*2,l,mid),build(o*2+1,mid+1,r);
t[o].sum = t[o*2].sum+t[o*2+1].sum;
}
void update(int o,int x,int v){
if (t[o].l==t[o].r){t[o].sum += v;return;}
int mid = (t[o].l+t[o].r)>>1;
if (x <= mid) update(o*2,x,v);
else update(o*2+1,x,v);
t[o].sum=t[o*2].sum+t[o*2+1].sum;
}
int query(int o,int l,int r){
if (l<=t[o].l && r>=t[o].r) return t[o].sum;
int mid = (t[o].l+t[o].r) >> 1;
int sum=0;
if (l<=mid) sum+=query(o*2,l,r);
if (r>mid) sum+=query(o*2+1,l,r);
return sum;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int n,m;
cin >> n ;
for (int i=1;i<=n;i++) cin >> a[i];
build(1,1,n);
cin >> m;
for (int i=1;i<=m;i++){
int flag;
cin >> flag;
if (flag == 2){
int x,y;
cin >> x >> y;
cout << query(1,x,y) << endl;
}
if (flag == 1){
int x,y;
cin >> x >> y;
update(1,x,y);
}
}
}
区间修改,区间查询
#include <bits/stdc++.h>
const int maxn = 2*1e5+100;
#define LL long long
using namespace std;
int a[maxn];
struct SegmentTree{
int l,r;
LL sum,add;
}t[maxn*4];
void build(int o,int l,int r){
t[o].l=l,t[o].r=r;
if (l == r) {t[o].sum = a[l];return;}
int mid = (l+r)>>1;
build(o*2,l,mid),build(o*2+1,mid+1,r);
t[o].sum = t[o*2].sum+t[o*2+1].sum;
}
void maintain(int o){
if (t[o].add){
t[o*2].sum += t[o].add*(t[o*2].r-t[o*2].l+1);
t[o*2+1].sum += t[o].add*(t[o*2+1].r-t[o*2+1].l+1);
t[o*2].add += t[o].add;
t[o*2+1].add += t[o].add;
t[o].add = 0;
}
}
void update(int o,int l,int r,int d){
if (l<=t[o].l && r>= t[o].r){
t[o].sum += d*(t[o].r-t[o].l+1);
t[o].add += d;
return;
}
maintain(o);
int mid = (t[o].l+t[o].r)>>1;
if (l <= mid) update(o*2,l,r,d);
if (r>mid) update(o*2+1,l,r,d);
t[o].sum=t[o*2].sum+t[o*2+1].sum;
}
LL query(int o,int l,int r){
if (l<=t[o].l && r>=t[o].r) return t[o].sum;
maintain(o);
int mid = (t[o].l+t[o].r) >> 1;
LL sum=0;
if (l<=mid) sum+=query(o*2,l,r);
if (r>mid) sum+=query(o*2+1,l,r);
return sum;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int n,m;
cin >> n ;
for (int i=1;i<=n;i++) cin >> a[i];
build(1,1,n);
cin >> m;
for (int i=1;i<=m;i++){
int flag;
cin >> flag;
if (flag == 2){
int x,y;
cin >> x >> y;
cout << query(1,x,y) << endl;
}
if (flag == 1){
int x,y,z;
cin >> x >> y >> z;
update(1,x,y,z);
}
}
}
有难度的例题
洛谷P3373 【模板】线段树 2
题目描述
已知一个数列,你需要进行下面三种操作:
1.将某区间每一个数乘上x
2.将某区间每一个数加上x
3.求出某区间每一个数的和
感觉并不模板,考虑到乘法和加法的先后顺序,在maintain的时候注意按照sum,mul,add的顺序maintain,add在维护时先乘上mul再加。(可以仔细想想这是为什么)
AC代码
#include <bits/stdc++.h>
const int maxn = 2*1e5+100;
#define LL long long
#define ls o*2
#define rs ls+1
using namespace std;
LL a[maxn],p;
struct SegmentTree{
int l,r;
LL sum,add,mul;
}t[maxn*4];
void build(int o,int l,int r){
t[o].l=l,t[o].r=r;t[o].mul=1;
if (l == r) {t[o].sum = a[l];return;}
int mid = (l+r)>>1;
build(ls,l,mid),build(rs,mid+1,r);
t[o].sum = t[ls].sum+t[rs].sum;
}
void maintain(int o){
t[ls].sum = t[ls].sum*t[o].mul %p;
t[rs].sum = t[rs].sum*t[o].mul % p;
t[ls].sum += t[o].add*(t[ls].r-t[ls].l+1);
t[rs].sum += t[o].add*(t[rs].r-t[rs].l+1);
t[ls].mul = t[o].mul * t[ls].mul % p;
t[rs].mul = t[o].mul * t[rs].mul % p;
t[ls].add = t[o].mul *t[ls].add %p;
t[rs].add = t[o].mul * t[rs].add % p;
t[ls].add = (t[o].add +t[ls].add)%p;
t[rs].add = (t[o].add +t[rs].add)%p;
t[o].add = 0;
t[o].mul = 1;
}
void addf(int o,int l,int r,int d){
if (l<=t[o].l && r>= t[o].r){
t[o].sum += d*(t[o].r-t[o].l+1);
t[o].add += d;
return;
}
maintain(o);
int mid = (t[o].l+t[o].r)>>1;
if (l <= mid) addf(ls,l,r,d);
if (r>mid) addf(rs,l,r,d);
t[o].sum=t[ls].sum+t[rs].sum;
}
void mulf(int o,int l,int r,int d){
if (l<=t[o].l && r>= t[o].r){
t[o].sum = t[o].sum * d % p;
t[o].add = t[o].add * d % p;
t[o].mul = d * t[o].mul % p;
return;
}
maintain(o);
int mid = (t[o].l+t[o].r)>>1;
if (l <= mid) mulf(ls,l,r,d);
if (r>mid) mulf(rs,l,r,d);
t[o].sum=t[ls].sum+t[rs].sum;
}
LL query(int o,int l,int r){
if (l<=t[o].l && r>=t[o].r) return t[o].sum;
maintain(o);
int mid = (t[o].l+t[o].r) >> 1;
LL sum=0;
if (l<=mid) sum=(sum+query(ls,l,r))%p;
if (r>mid) sum=(query(rs,l,r)+sum)%p;
return sum;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int n,m;
cin >> n >> m >> p;
for (int i=1;i<=n;i++) cin >> a[i];
build(1,1,n);
for (int i=1;i<=m;i++){
int flag;
cin >> flag;
if (flag == 3){
int x,y;
cin >> x >> y;
cout << query(1,x,y) << endl;
}
if (flag == 2){
int x,y,z;
cin >> x >> y >> z;
addf(1,x,y,z);
}
if (flag == 1){
int x,y,z;
cin >> x >> y >> z;
mulf(1,x,y,z);
}
}
}
CH4302 Interval GCD
题目描述
已知一个数列,你需要进行下面两种操作:
1.将某区间每一个数加上x
2.求出某区间的最大公约数
我们知道九章算术之更相减损术,gcd(x,y) = gcd(x,y-x),这个结论可以被扩展到更多个数的情况,gcd(x,y,z) = gcd(x,y-x,z-y),事实上,对于任意多个整数都适用。
因此,构造一个新的差分序列 B[i] = A[i] - A[i-1],B[1]可取任意值,
这样一来,询问就转化为了求出 gcd(A{l],ask_max(1,l+1,r)),即求出左端点和右边区间最大值的gcd。
我们再反过来看update怎么改变。因为是差分数列,区间加转化为了B[l]+=d,B[r+1]-=d,做两次单点修改。
着实是妙 啊!
CH4301
题目描述
已知一个数列,你需要进行下面两种操作:
1查询区间的最大连续字段和
2.单点修改值
在线段树的每个节点维护四个信息,区间和sum,区间最大连续字段和dat,紧靠左端的最大连续字段和 lmax,紧靠右端的最大连续字段和 rmax 。
SegmentTreed的真题框架不变,只需要在build和update函数中从下往上传递信息:
t[p].sum = t[ls].sum+t.[rs].sum;
t[p].lmax = max(t[ls].lmax,t[ls].sum+t[rs].lmax);
t[p].rmax = max(t[rs].rmax,t[rs].sum+t[ls].rmax);
t[p].dat = max(t[ls].dat,t[ rs].dat,t[ls].rmax+t[rs].lmax);
WA代码(施工中)
#include <bits/stdc++.h>
const int maxn = 5*1e5+100;
#define ls o*2
#define rs o*2+1
using namespace std;
int a[maxn];
struct SegmentTree{
int l,r;
int sum,dat,lmax,rmax;
}t[maxn*4];
int maxxx(int a,int b,int c){
int t=max(a,b);
t=max(t,c);
return t;
}
void build(int o,int l,int r){
t[o].l=l,t[o].r=r;
if (l == r) {
t[o].sum =t[o].lmax=t[o].rmax=t[o].dat= a[l];
return;
}
int mid = (l+r)>>1;
build(o*2,l,mid),build(o*2+1,mid+1,r);
t[o].sum = t[o*2].sum+t[o*2+1].sum;
t[o].lmax = max(t[ls].lmax,t[ls].sum+t[rs].lmax);
t[o].rmax = max(t[rs].rmax,t[rs].sum+t[ls].rmax);
t[o].dat = maxxx(t[ls].dat,t[rs].dat,t[ls].rmax+t[rs].lmax);
}
void update(int o,int x,int v){
if (t[o].l==t[o].r){
t[o].sum =t[o].lmax=t[o].rmax=t[o].dat=v;
return;
}
int mid = (t[o].l+t[o].r)>>1;
if (x <= mid) update(o*2,x,v);
else update(o*2+1,x,v);
t[o].sum=t[o*2].sum+t[o*2+1].sum;
t[o].lmax = max(t[ls].lmax,t[ls].sum+t[rs].lmax);
t[o].rmax = max(t[rs].rmax,t[rs].sum+t[ls].rmax);
t[o].dat = maxxx(t[ls].dat,t[ rs].dat,t[ls].rmax+t[rs].lmax);
}
int query(int o,int l,int r){
if (l<=t[o].l && r>=t[o].r) return maxxx(t[o].rmax,t[o].lmax,t[o].dat);
int mid = (t[o].l+t[o].r) >> 1;
int sum=-0x3f3f3f;
if (l<=mid) sum=max(query(o*2,l,r),sum);
if (r>mid) sum=max(query(o*2+1,l,r),sum);
return sum;
}
int main(){
ios::sync_with_stdio(false);
cin.tie(0);
int n,m;
cin >> n >> m;
for (int i=1;i<=n;i++) cin >> a[i];
build(1,1,n);
// for (int i=1;i<=n;i++)cout << query(1,i,i)<<' ';
// cout << endl;
for (int i=1;i<=m;i++){
int flag;
cin >> flag;
if (flag == 1){
int x,y;
cin >> x >> y;
if (x>y){int t=x;x=y;y=t;}
cout << query(1,x,y) << endl;
}
if (flag == 2){
int x,y;
cin >> x >> y;
update(1,x,y);
}
}
// for (int i=1;i<=n;i++)cout << query(1,i,i)<<' ';
}