线段树&树状数组练习题
前导知识
:
1. 【模板】线段树1
#include <iostream>
using namespace std;
#define int long long
int n,q,a[100005],d[270000],b[270000];
void build(int s,int t,int p){
/* 对[s,t]区间建立线段树,当前根的编号为p */
if(s==t){
d[p]=a[s];
return;
}
int m=s+((t-s)>>1);
build(s,m,p*2),build(m+1,t,p*2+1);
/* 递归对左右区间建树 */
d[p]=d[p*2]+d[(p*2)+1];
}
void update(int l,int r,int c,int s,int t,int p){
/* [l,r] 为修改区间,c 为被修改的元素的变化量,[s,t] 为当前节点包含的区间,p为当前节点的编号 */
if(l<=s&&t<=r){
d[p]+=(t-s+1)*c,b[p]+=c;
return;
}
/* 当前区间为修改区间的子集时直接修改当前节点的值,然后打标记,结束修改 */
int m=s+((t-s)>>1);
if(b[p]&&s!=t){
/* 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值 */
d[p*2]+=b[p]*(m-s+1),d[p*2+1]+=b[p]*(t-m);
b[p*2]+=b[p],b[p*2+1]+=b[p]; /* 将标记下传给子节点 */
b[p]=0; /* 清空当前节点的标记 */
}
if (l <= m) update(l, r, c, s, m, p * 2);
if (r > m) update(l, r, c, m + 1, t, p * 2 + 1);
d[p] = d[p * 2] + d[p * 2 + 1];
}
int getsum(int l, int r, int s, int t, int p) {
/* [l,r] 为查询区间,[s,t] 为当前节点包含的区间,p为当前节点的编号 */
if (l <= s && t <= r) return d[p];
/* 当前区间为询问区间的子集时直接返回当前区间的和 */
int m = s + ((t - s) >> 1);
if (b[p]) {
/* 如果当前节点的懒标记非空,则更新当前节点两个子节点的值和懒标记值 */
d[p * 2] += b[p] * (m - s + 1), d[p * 2 + 1] += b[p] * (t - m),
b[p * 2] += b[p], b[p * 2 + 1] += b[p]; /* 将标记下传给子节点 */
b[p] = 0; /* 清空当前节点的标记 */
}
int sum = 0;
if (l <= m) sum = getsum(l, r, s, m, p * 2);
if (r > m) sum += getsum(l, r, m + 1, t, p * 2 + 1);
return sum;
}
signed main(){
scanf("%lld %lld",&n,&q);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
build(1,n,1);
while(q--){
int i1,i2,i3,i4;
scanf("%lld %lld %lld",&i1,&i2,&i3);
if(i1==2)
printf("%lld\n",getsum(i2,i3,1,n,1));
else
scanf("%lld",&i4),update(i2,i3,i4,1,n,1);
}
return 0;
}
2. 【模板】线段树2
#include <cstdio>
#define int long long
int n,m;
int mod;
int opt,x,y,z;
int a[100005],sum[400005],mul[400005],laz[400005];
void push_up(int i){
sum[i]=(sum[(i<<1)]+sum[(i<<1)|1])%mod;
}
void push_down(int i,int s,int t){
int l=(i<<1),r=(i<<1)|1,mid=(s+t)>>1;
if(mul[i]!=1){
mul[l] *= mul[i];
mul[l] %= mod;
mul[r] *= mul[i];
mul[r] %= mod;
laz[l] *= mul[i];
laz[l] %= mod;
laz[r] *= mul[i];
laz[r] %= mod;
sum[l] *= mul[i];
sum[l] %= mod;
sum[r] *= mul[i];
sum[r] %= mod;
mul[i] = 1;
}
if (laz[i]) {
sum[l] += laz[i] * (mid - s + 1);
sum[l] %= mod;
sum[r] += laz[i] * (t - mid);
sum[r] %= mod;
laz[l] += laz[i];
laz[l] %= mod;
laz[r] += laz[i];
laz[r] %= mod;
laz[i] = 0;
}
return;
}
void build(int s,int t,int i){
mul[i]=1;
if(s==t){
sum[i]=a[s];
return;
}
int mid=s+((t-s)>>1);
build(s,mid,i<<1);
build(mid+1,t,(i<<1)|1);
push_up(i);
}
void multiply(int l,int r,int z,int s,int t,int i){
int mid=s+((t-s)>>1);
if(l<=s&&t<=r){
mul[i] *= z;
mul[i] %= mod;
laz[i] *= z;
laz[i] %= mod;
sum[i] *= z;
sum[i] %= mod;
return;
}
push_down(i,s,t);
if(mid>=l) multiply(l,r,z,s,mid,(i<<1));
if(mid+1<=r) multiply(l,r,z,mid+1,t,(i<<1)|1);
push_up(i);
}
void add(int l,int r,int z,int s,int t,int i){
int mid=s+((t-s)>>1);
if(l<=s&&t<=r){
sum[i]+=z*(t-s+1);
sum[i]%=mod;
laz[i]+=z;
laz[i]%=mod;
return;
}
push_down(i,s,t);
if(mid>=l) add(l,r,z,s,mid,(i<<1));
if(mid+1<=r) add(l,r,z,mid+1,t,(i<<1)|1);
push_up(i);
}
int getans(int l,int r,int s,int t,int i){
int mid=s+((t-s)>>1);
int tot=0;
if(l<=s&&t<=r) return sum[i];
push_down(i,s,t);
if(mid>=l) tot+=getans(l,r,s,mid,(i<<1));
tot%=mod;
if(mid+1<=r) tot+=getans(l,r,mid+1,t,(i<<1)|1);
return tot%mod;
}
signed main(){
scanf("%lld %lld %lld",&n,&m,&mod);
for(int i=1;i<=n;i++) scanf("%lld",&a[i]);
build(1,n,1);
for(int i=1;i<=m;i++){
scanf("%lld",&opt);
if(opt==1){
scanf("%lld %lld %lld",&x,&y,&z);
multiply(x,y,z,1,n,1);
}else if(opt==2){
scanf("%lld %lld %lld",&x,&y,&z);
add(x,y,z,1,n,1);
}else if(opt==3){
scanf("%lld %lld",&x,&y);
printf("%lld\n",getans(x,y,1,n,1));
}
}
return 0;
}
3. 逆序对
#include <iostream>
#include<algorithm>
using namespace std;
#define int long long
const int N = 5e5 + 5;
const int INF = 0x7ffffff;
int n;
int a[N],tmp[N];
template <typename T>
inline void read(T &x) {
register T c = getchar();
for (; c < 48 || 57 < c; c = getchar())
;
for (; 48 <= c && c <= 57; c = getchar())
x = (x << 3) + (x << 1) + (c & 15);
}
template <typename T>
inline void print(T x) {
if (x > 9)
print(x / 10);
putchar(x % 10 | 48);
}
int CDQ(int L,int R){
if(L==R) return 0;
int mid=L + R >>1;
int ans=CDQ(L,mid)+CDQ(mid+1,R);
int i=L,j=mid+1,k=L;
while(i<=mid&&j<=R){
if(a[i]<=a[j]) tmp[k++]=a[i++];
else tmp[k++]=a[j++],ans+=mid-i+1;
}
while(i<=mid) tmp[k++]=a[i++];
while(j<=R) tmp[k++]=a[j++];
for(int i=L;i<=R;i++){
a[i]=tmp[i];
}
return ans;
}
signed main(){
read(n);
for(int i=1;i<=n;i++){
read(a[i]);
}
print(CDQ(1,n));
return 0;
}
4. 无聊的数列
#include <cstdio>
#define int long long
const int N = 1e5 + 5;
int n,m;
int a[N];
int opt;
int l,r,K,D;
int p;
signed main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
}
while(m--){
scanf("%d",&opt);
if(opt==1){
scanf("%d%d%d%d",&l,&r,&K,&D);
for(int i=l;i<=r;i++){
a[i]+=K+(i-l)*D;
}
}else if(opt==2){
scanf("%d",&p);
printf("%d\n",a[p]);
}
}
return 0;
}