题意:给一个大小为n的数组,动态第改变其中的值,求最大子段和,这里子段可以由尾部到首部(比如1,2,3,4,5,这里子段可以是2,3,4也可以是5,1,2)。
思路:用线段树可以求动态序列的最大子段和,这题的关键是处理从尾部到首部的情况,其实这种情况只要用总和减去最小子段和就可以了。所以用线段树维护最大子段和以及最小子段和就可以了。
代码:
//poj 2750
//sepNINE
#include <stdio.h>
#define maxN 100024
struct node {
int r,l,mid;
int sum,rmax,rmin,lmax,lmin,mmax,mmin;
}T[maxN*3];
int a[maxN];
int max(int a,int b)
{
return a<b?b:a;
}
int min(int a,int b)
{
return a<b?a:b;
}
void build(int l, int r, int k)
{
int mid=(l+r)/2;
T[k].l=l; T[k].r=r; T[k].mid=mid;
if(l==r){
T[k].sum=a[l]; T[k].rmax=a[l]; T[k].rmin=a[l];
T[k].lmax=a[l]; T[k].lmin=a[l];
T[k].mmax=a[l]; T[k].mmin=a[l];
return ;
}
build( l, mid, 2*k );
build( mid+1, r, 2*k+1 );
T[k].sum=T[2*k].sum+T[2*k+1].sum;
T[k].rmax=max( T[2*k+1].rmax, T[2*k].rmax+T[2*k+1].sum );
T[k].rmin=min( T[2*k+1].rmin, T[2*k].rmin+T[2*k+1].sum );
T[k].lmax=max( T[2*k].lmax, T[2*k+1].lmax+T[2*k].sum );
T[k].lmin=min( T[2*k].lmin, T[2*k+1].lmin+T[2*k].sum );
int a,b;
a=max( T[2*k].mmax, T[2*k+1].mmax );
b=max( T[k].lmax, T[k].rmax );
T[k].mmax=max( a, b );
T[k].mmax=max( T[k].mmax, T[2*k].rmax+T[2*k+1].lmax );
a=min( T[2*k].mmin, T[2*k+1].mmin );
b=min( T[k].lmin, T[k].rmin );
T[k].mmin=min( a, b );
T[k].mmin=min( T[k].mmin, T[2*k].rmin+T[2*k+1].lmin );
return ;
}
void modify(int x, int y, int k)
{
if(T[k].l==T[k].r){
T[k].sum=y; T[k].rmax=y; T[k].rmin=y;
T[k].lmax=y; T[k].lmin=y;
T[k].mmax=y; T[k].mmin=y;
return ;
}
int mid=T[k].mid;
if(x<=mid)
modify(x,y,2*k);
else
modify(x,y,2*k+1);
T[k].sum=T[2*k].sum+T[2*k+1].sum;
T[k].rmax=max( T[2*k+1].rmax, T[2*k].rmax+T[2*k+1].sum );
T[k].rmin=min( T[2*k+1].rmin, T[2*k].rmin+T[2*k+1].sum );
T[k].lmax=max( T[2*k].lmax, T[2*k+1].lmax+T[2*k].sum );
T[k].lmin=min( T[2*k].lmin, T[2*k+1].lmin+T[2*k].sum );
int a,b;
a=max( T[2*k].mmax, T[2*k+1].mmax );
b=max( T[k].lmax, T[k].rmax );
T[k].mmax=max( a, b );
T[k].mmax=max( T[k].mmax, T[2*k].rmax+T[2*k+1].lmax );
a=min( T[2*k].mmin, T[2*k+1].mmin );
b=min( T[k].lmin, T[k].rmin );
T[k].mmin=min( a, b );
T[k].mmin=min( T[k].mmin, T[2*k].rmin+T[2*k+1].lmin );
return ;
}
int main() {
int i,n,m;
scanf("%d", &n);
for( i=1; i<=n; ++i )
scanf("%d" , &a[i]);
build( 1, n, 1);
scanf("%d", &m);
while(m--){
int x,y;
scanf("%d%d", &x, &y);
modify( x, y, 1);
if(T[1].mmin<0)
printf("%d\n", max( T[1].mmax, T[1].sum-T[1].mmin));
else
printf("%d\n",T[1].sum-T[1].mmin);
}
return 0;
}