用线段树维护,每次取出最大子段和并把这一段区间取反,重复
k
次。复杂度
有一个比较巧妙的证明,按照区间选择的模型建最大费用流,每次增广最长路就是选择最大的区间并取反。因为费用流的做法是对的,所以这样贪心就是对的。
#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=1000010;
int a[maxn],ll[40],rr[40],flag,n;
struct node
{
int sum,mx,lmx,rmx,mn,lmn,rmn,inv,
l,r,mxl,mxr,lmxr,rmxl,lmnr,rmnl,mnl,mnr;
void init(int p)
{
l=r=mxl=mxr=mnl=mnr=lmxr=lmnr=rmxl=rmnl=p;
sum=mx=lmx=rmx=mn=lmn=rmn=a[p];
}
void upd(node L,node R)
{
l=L.l;
r=R.r;
sum=L.sum+R.sum;
if (L.lmx>L.sum+R.lmx) lmx=L.lmx,lmxr=L.lmxr;
else lmx=L.sum+R.lmx,lmxr=R.lmxr;
if (L.lmn<L.sum+R.lmn) lmn=L.lmn,lmnr=L.lmnr;
else lmn=L.sum+R.lmn,lmnr=R.lmnr;
if (R.rmx>R.sum+L.rmx) rmx=R.rmx,rmxl=R.rmxl;
else rmx=R.sum+L.rmx,rmxl=L.rmxl;
if (R.rmn<R.sum+L.rmn) rmn=R.rmn,rmnl=R.rmnl;
else rmn=R.sum+L.rmn,rmnl=L.rmnl;
if (L.mx>R.mx&&L.mx>L.rmx+R.lmx) mx=L.mx,mxl=L.mxl,mxr=L.mxr;
else if (R.mx>L.rmx+R.lmx) mx=R.mx,mxl=R.mxl,mxr=R.mxr;
else mx=L.rmx+R.lmx,mxl=L.rmxl,mxr=R.lmxr;
if (L.mn<R.mn&&L.mn<L.rmn+R.lmn) mn=L.mn,mnl=L.mnl,mnr=L.mnr;
else if (R.mn<L.rmn+R.lmn) mn=R.mn,mnl=R.mnl,mnr=R.mnr;
else mn=L.rmn+R.lmn,mnl=L.rmnl,mnr=R.lmnr;
}
void clear()
{
inv=0;
sum=-sum;
swap(lmx,lmn);
lmx=-lmx;
lmn=-lmn;
swap(lmxr,lmnr);
swap(rmx,rmn);
rmx=-rmx;
rmn=-rmn;
swap(rmxl,rmnl);
swap(mx,mn);
mx=-mx;
mn=-mn;
swap(mxl,mnl);
swap(mxr,mnr);
}
}t[maxn],now;
void build(int u,int L,int R)
{
if (L==R) t[u].init(L);
else
{
int mid=L+R>>1;
build(u<<1,L,mid);
build(u<<1|1,mid+1,R);
t[u].upd(t[u<<1],t[u<<1|1]);
}
}
void down(int u)
{
if (t[u].inv)
{
t[u].clear();
if (t[u].l<t[u].r)
{
t[u<<1].inv^=1;
t[u<<1|1].inv^=1;
}
}
}
void modify(int u,int p)
{
down(u);
if (t[u].l==t[u].r) t[u].init(p);
else
{
int mid=t[u].l+t[u].r>>1;
if (p<=mid) modify(u<<1,p);
else modify(u<<1|1,p);
down(u<<1);
down(u<<1|1);
t[u].upd(t[u<<1],t[u<<1|1]);
}
}
void find(int u,int l,int r)
{
down(u);
if (l<=t[u].l&&t[u].r<=r)
{
if (!flag) now=t[u],flag=1;
else now.upd(now,t[u]);
return;
}
int mid=t[u].l+t[u].r>>1;
if (l<=mid) find(u<<1,l,r);
if (r>mid) find(u<<1|1,l,r);
}
void mark(int u,int l,int r)
{
down(u);
if (l<=t[u].l&&t[u].r<=r) t[u].inv=1;
else
{
int mid=t[u].l+t[u].r>>1;
if (l<=mid) mark(u<<1,l,r);
if (r>mid) mark(u<<1|1,l,r);
down(u<<1);
down(u<<1|1);
t[u].upd(t[u<<1],t[u<<1|1]);
}
}
int main()
{
int x,y,k,q,opt,ans,cnt;
scanf("%d",&n);
for (int i=1;i<=n;i++) scanf("%d",&a[i]);
build(1,1,n);
scanf("%d",&q);
while (q--)
{
scanf("%d%d%d",&opt,&x,&y);
if (opt==0) a[x]=y,modify(1,x);
else
{
scanf("%d",&k);
ans=cnt=0;
while (k--)
{
flag=0;
find(1,x,y);
if (now.mx>0)
{
ans+=now.mx;
cnt++;
ll[cnt]=now.mxl;
rr[cnt]=now.mxr;
mark(1,now.mxl,now.mxr);
}
else break;
}
printf("%d\n",ans);
for (;cnt;cnt--) mark(1,ll[cnt],rr[cnt]);
}
}
}