线段树例题&代码详解
终于有空写题解了,小学课业真烦
关于线段树的初步认识,可以看我的博客,也就是这里
话不多说,先看例题一:
洛谷P3372,模板线段树1
首先分析一下,这道题还简单,是仅仅让我们进行区间查询与区间加法操作
简单复习下思路:
因为是要给区间 [l,r] 进行加法操作,那么我们可以计算一下数学公式:
[
l
,
r
]
+
l
z
=
(
r
−
l
+
1
)
∗
l
z
[l,r]\ +\ lz \\ =(r-l+1)\ *\ lz
[l,r] + lz=(r−l+1) ∗ lz
代码:
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;//不开long long 见祖宗
const int MAXN=1e5+5;
ll a[MAXN];
struct NODE{
ll l,r,sum,lz;
}node[MAXN<<2];
ll n,m,op,x,y,z;
void build(ll l,ll r,ll x)//建树
{
if(l==r)
{
node[x].sum=a[r];
return;
}
ll mid=(l+r)/2;
build(l,mid,x*2);
build(mid+1,r,x*2+1);
node[x].sum+=node[x*2+1].sum+node[x*2].sum;
return;
}
ll getnum(ll l,ll r,ll s,ll t,ll x)//计算
{
if(l<=s&&r>=t)
{
return node[x].sum;
}
ll mid,sum=0;
mid=(s+t)/2;
if(node[x].lz!=0)
{
node[x*2].sum+=(mid-s+1)*node[x].lz;
node[x*2+1].sum+=(t-mid)*node[x].lz;
node[x*2].lz+=node[x].lz;
node[x*2+1].lz+=node[x].lz;
}
node[x].lz=0;
if(l<=mid)
{
sum=getnum(l,r,s,mid,x*2);
}
if(r>mid)
{
sum+=getnum(l,r,mid+1,t,x*2+1);
}
return sum;
}
void update(ll l,ll r,ll c,ll s,ll t,ll p)//修改区间数值
{
if(l<=s&&t<=r)
{
node[p].sum+=(t-s+1)*c;
node[p].lz+=c;
return;
}
ll mid=(s+t)/2;
if(node[p].lz!=0)
{
node[p*2].sum+=(mid-s+1)*node[p].lz;
node[p*2+1].sum+=(t-mid)*node[p].lz;
node[p*2].lz+=node[p].lz;
node[p*2+1].lz+=node[p].lz;
}
node[p].lz=0;
if(l<=mid)
{
update(l,r,c,s,mid,p*2);
}
if(r>mid)
{
update(l,r,c,mid+1,t,p*2+1);
}
node[p].sum=node[p*2].sum+node[p*2+1].sum;
return;
}
int main(){
scanf("%lld%lld",&n,&m);
for(ll i=1;i<=n;i++)
{
scanf("%lld",&a[i]);
}
build(1,n,1);
while(m--)
{
scanf("%d%d%d",&op,&x,&y);
if(op==1)
{
scanf("%lld",&z);
update(x,y,z,1,n,1);
}
else
{
printf("%lld\n",getnum(x,y,1,n,1));
}
}
return 0;
}
例题二:
洛谷P3373,线段树2
分析下题目,发现这题与上面区别就是多了个区间乘法,当然懒标记下传也是有区别的
我们先来推断一下数学公式,还是设需要更改的区间为[l,r],修改值lc:
[
l
,
r
]
∗
l
c
=
(
a
l
+
a
l
+
1
+
a
l
+
2
+
.
.
.
+
a
r
−
1
+
a
r
)
∗
l
c
=
∑
l
r
∗
l
c
=
s
u
m
∗
l
c
[l,r]\ *\ lc \\ =\ ( a_l\ +\ a_{l+1}\ +\ a_{l+2}\ +\ ...\ +\ a_{r-1} \ +\ a_r ) *\ lc \\ =\sum _{l}^{r}\ *\ lc \\ =\ sum\ *\ lc
[l,r] ∗ lc= (al + al+1 + al+2 + ... + ar−1 + ar)∗ lc=l∑r ∗ lc= sum ∗ lc
以及在下传懒标记时lz需要为0,lc则为1
代码:
#include<bits/stdc++.h>
using namespace std;
const int MAXN=1e5+5;
typedef long long ll;
int n,m,mod;
int op,x,y,z;
ll a[MAXN];
struct NODE{
ll l,r,sum,lz,lc;
}node[MAXN<<2];
void pushdown(int p)
{
if(node[p].lc!=1)
{
node[p<<1].sum=(node[p].lc*node[p<<1].sum)%mod;
node[(p<<1)+1].sum=(node[p].lc*node[(p<<1)+1].sum)%mod;
node[p<<1].lz=(node[p<<1].lz*node[p].lc)%mod;
node[(p<<1)+1].lz=(node[(p<<1)+1].lz*node[p].lc)%mod;
node[p<<1].lc=(node[p<<1].lc*node[p].lc)%mod;
node[(p<<1)+1].lc=(node[(p<<1)+1].lc*node[p].lc)%mod;
}
if(node[p].lz!=0)
{
node[p<<1].sum=(node[p<<1].sum+(node[p<<1].r-node[p<<1].l+1)*node[p].lz);
node[(p<<1)+1].sum=(node[(p<<1)+1].sum+(node[(p<<1)+1].r-node[(p<<1)+1].l+1)*node[p].lz)%mod;
node[p<<1].lz=(node[p<<1].lz+node[p].lz)%mod;
node[(p<<1)+1].lz=(node[(p<<1)+1].lz+node[p].lz)%mod;
}
node[p].lz=0;
node[p].lc=1;
return;
}
void build(int l,int r,int p)
{
node[p].l=l;
node[p].r=r;
node[p].lc=1;
if(l==r)
{
node[p].sum=a[l];
return;
}
int mid=(l+r)/2;
build(l,mid,p<<1);
build(mid+1,r,(p<<1)+1);
node[p].sum=(node[p<<1].sum+node[(p<<1)+1].sum)%mod;
return;
}
void add(int l,int r,ll c,int p)
{
if(l>node[p].r||r<node[p].l) return;
if(l<=node[p].l&&node[p].r<=r)
{
node[p].lz=(node[p].lz+c)%mod;
node[p].sum=(node[p].sum+(node[p].r-node[p].l+1)*c)%mod;
return;
}
int mid=(l+r)/2;
pushdown(p);
add(l,r,c,p<<1);
add(l,r,c,(p<<1)+1);
node[p].sum=(node[p<<1].sum+node[(p<<1)+1].sum)%mod;
return;
}
void mul(int l,int r,ll c,int p)
{
if(node[p].l>r||node[p].r<l) return;
if(l<=node[p].l&&r>=node[p].r)
{
node[p].lz=(node[p].lz*c)%mod;
node[p].lc=(node[p].lc*c)%mod;
node[p].sum=(node[p].sum*c)%mod;
return;
}
pushdown(p);
mul(l,r,c,p<<1);
mul(l,r,c,(p<<1)+1);
node[p].sum=(node[p<<1].sum+node[(p<<1)+1].sum)%mod;
return;
}
ll getnum(int l,int r,int p)
{
if(node[p].l>r||node[p].r<l) return 0;
if(node[p].l>=l&&node[p].r<=r)return node[p].sum;
pushdown(p);
return (getnum(l,r,p<<1)+getnum(l,r,(p<<1)+1))%mod;
}
int main(){
scanf("%d%d%d",&n,&m,&mod);
for(int i=1;i<=n;i++)
scanf("%lld",&a[i]);
build(1,n,1);
for(int i=1;i<=m;i++)
{
scanf("%d%d%d",&op,&x,&y);
if(op==1)
{
scanf("%d",&z);
mul(x,y,z,1);
}
else
{
if(op==2)
{
scanf("%d",&z);
add(x,y,z,1);
}
else printf("%lld\n",getnum(x,y,1));
}
}
return 0;
}