解题思路
我们对线段树的每个节点维护几个信息:sum(区间和),maxx(区间最大子段和),lmax(从区间左端点开始的最大子段和),rmax(从区间右端点开始的最大子段和)。用fa表示当前区间,lson,rson表示其左右区间,我们可以得到如下转移:
- s u m [ f a ] = s u m [ l s o n ] + s u m [ r s o n ] sum[fa]=sum[lson]+sum[rson] sum[fa]=sum[lson]+sum[rson]
- 最大子段和要么全在左区间,右区间,或者左右均有,取最大转移:
m a x x [ f a ] = m a x ( m a x ( m a x x [ l s o n ] , m a x x [ r s o n ] ) , r m a x [ l s o n ] + l m a x [ r s o n ] ) maxx[fa]=max(max(maxx[lson],maxx[rson]),rmax[lson]+lmax[rson]) maxx[fa]=max(max(maxx[lson],maxx[rson]),rmax[lson]+lmax[rson]) - 从左端点开始的最大子段和要么在左区间,要么是左区间的全部加上右区间中以左端点开始的最大子段和:
l m a x = m a x ( l m a x [ l s o n ] , s u m [ l s o n ] + l m a x [ r s o n ] ) lmax=max(lmax[lson],sum[lson]+lmax[rson]) lmax=max(lmax[lson],sum[lson]+lmax[rson]) - 同理得:
r m a x = m a x ( r m a x [ r s o n ] , s u m [ r s o n ] + r m a x [ l s o n ] ) rmax=max(rmax[rson],sum[rson]+rmax[lson]) rmax=max(rmax[rson],sum[rson]+rmax[lson])
代码
#include<iostream>
#include<cstdio>
#include<iomanip>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<queue>
using namespace std;
int n,m,K,p,q,a[500010];
struct c{
long long l,r,maxx,num;
}tree[2000010];
c add(c x,c y)
{
c z;
z.num=x.num+y.num;
z.l=max(x.l,x.num+y.l);
z.r=max(y.r,x.r+y.num);
z.maxx=max(max(x.maxx,y.maxx),x.r+y.l);
return z;
}
void build(int k,int l,int r)
{
if(l==r)
{
tree[k].num=a[l];
tree[k].l=a[l];
tree[k].r=a[l];
tree[k].maxx=a[l];
return;
}
int mid=(l+r)/2;
build(k*2,l,mid);
build(k*2+1,mid+1,r);
tree[k]=add(tree[k*2],tree[k*2+1]);
}
void change(int k,int l,int r,int x,int v){
if(l==r)
{
tree[k].num=v;
tree[k].l=v;
tree[k].r=v;
tree[k].maxx=v;
return;
}
int mid=(l+r)/2;
if(x<=mid)change(k*2,l,mid,x,v);
else change(k*2+1,mid+1,r,x,v);
tree[k]=add(tree[k*2],tree[k*2+1]);
}
c query(int k,int l,int r,int x,int y){
if(l>=x&&r<=y)return tree[k];
int mid=(l+r)/2;
if(y<=mid)return query(k*2,l,mid,x,y);
if(x>mid)return query(k*2+1,mid+1,r,x,y);
c a,b,c;
a=query(k*2,l,mid,x,y);
b=query(k*2+1,mid+1,r,x,y);
c=add(a,b);
return c;
}
int main(){
scanf("%d%d",&n,&m);
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
build(1,1,n);
for(int i=1;i<=m;i++)
{
scanf("%d%d%d",&K,&p,&q);
if(K==2)
change(1,1,n,p,q);
if(K==1)
{
if(p>q)swap(p,q);
c ans=query(1,1,n,p,q);
printf("%lld\n",ans.maxx);
}
}
}