经典线段树问题,需要用上懒操作
对于每个区间,记录该区间从左端点开始连续的1和0的个数,从右端点开始连续的1和0的个数,整个区间最多的连续1和0的个数,以及这个区间是否被翻转过
每次访问子区间时,如果父区间被翻转过,子区间也要被翻转,标记下移,更新完子区间后再向上更新
查询函数中不用向上更新,因为如果子区间被翻转,在更新函数中父区间也肯定被翻转过或更新过,父区间的信息已经被修改,就算向上更新也不会改变父区间的状态
代码:
#include<iostream>
#include<memory.h>
#include<string>
#include<cstdio>
#include<algorithm>
#include<math.h>
#include<stack>
#include<queue>
#include<vector>
#include<map>
#include<ctime>
using namespace std;
const int MAX=100005;
struct node
{
int l,r,cnt,len;
int wl,wr,wv,bl,br,bv;
}t[MAX*5];
int a[MAX],n,m;
void build(int ll,int rr,int n)
{
t[n].l=ll; t[n].r=rr; t[n].len=rr-ll+1;
t[n].cnt=0;
if(ll==rr)
{
if(a[ll]==0)
{
t[n].wl=t[n].wr=t[n].wv=1;
t[n].bl=t[n].br=t[n].bv=0;
}
else
{
t[n].wl=t[n].wr=t[n].wv=0;
t[n].bl=t[n].br=t[n].bv=1;
}
return;
}
int mid=(ll+rr)/2;
build(ll,mid,n*2);
build(mid+1,rr,n*2+1);
if(t[n*2].bl==t[n*2].len&&a[t[n*2].r+1]==1)
t[n].bl=t[n*2].bl+t[n*2+1].bl;
else
t[n].bl=t[n*2].bl;
if(t[n*2+1].br==t[n*2+1].len&&a[t[n*2].r]==1)
t[n].br=t[n*2+1].br+t[n*2].br;
else
t[n].br=t[n*2+1].br;
t[n].bv=max(t[n*2].bv,t[n*2+1].bv);
t[n].bv=max(t[n].bv,t[n*2].br+t[n*2+1].bl);
if(t[n*2].wl==t[n*2].len&&a[t[n*2].r+1]==0)
t[n].wl=t[n*2].wl+t[n*2+1].wl;
else
t[n].wl=t[n*2].wl;
if(t[n*2+1].wr==t[n*2+1].len&&a[t[n*2].r]==0)
t[n].wr=t[n*2+1].wr+t[n*2].wr;
else
t[n].wr=t[n*2+1].wr;
t[n].wv=max(t[n*2].wv,t[n*2+1].wv);
t[n].wv=max(t[n].wv,t[n*2].wr+t[n*2+1].wl);
}
void change1(int n)
{
swap(t[n].wl,t[n].bl);
swap(t[n].wr,t[n].br);
swap(t[n].wv,t[n].bv);
}
void change(int n)
{
if(t[n].l==t[n].r)
return;
if(t[n*2].bl==t[n*2].len)
t[n].bl=t[n*2].bl+t[n*2+1].bl;
else
t[n].bl=t[n*2].bl;
if(t[n*2+1].br==t[n*2+1].len)
t[n].br=t[n*2+1].br+t[n*2].br;
else
t[n].br=t[n*2+1].br;
t[n].bv=max(t[n*2].bv,t[n*2+1].bv);
t[n].bv=max(t[n].bv,t[n*2].br+t[n*2+1].bl);
if(t[n*2].wl==t[n*2].len)
t[n].wl=t[n*2].wl+t[n*2+1].wl;
else
t[n].wl=t[n*2].wl;
if(t[n*2+1].wr==t[n*2+1].len)
t[n].wr=t[n*2+1].wr+t[n*2].wr;
else
t[n].wr=t[n*2+1].wr;
t[n].wv=max(t[n*2].wv,t[n*2+1].wv);
t[n].wv=max(t[n].wv,t[n*2].wr+t[n*2+1].wl);
}
void update(int ll,int rr,int n,int val)
{
if(t[n].l==ll&&t[n].r==rr)
{
t[n].cnt^=val;
change1(n);
return;
}
if(t[n].cnt)
{
t[n*2].cnt^=t[n].cnt; t[n*2+1].cnt^=t[n].cnt;
change1(n*2);
change1(n*2+1);
t[n].cnt=0;
}
int mid=(t[n].l+t[n].r)/2;
if(mid>=rr)
update(ll,rr,n*2,val);
else if(mid<ll)
update(ll,rr,n*2+1,val);
else
{
update(ll,mid,n*2,val);
update(mid+1,rr,n*2+1,val);
}
change(n);
}
int query(int ll,int rr,int n)
{
if(t[n].l==ll&&t[n].r==rr)
{
return t[n].bv;
}
if(t[n].cnt)
{
t[n*2].cnt^=t[n].cnt; t[n*2+1].cnt^=t[n].cnt;
change1(n*2);
change1(n*2+1);
t[n].cnt=0;
}
int mid=(t[n].l+t[n].r)/2;
if(mid>=rr)
return query(ll,rr,n*2);
else if(mid<ll)
return query(ll,rr,n*2+1);
else
{
int x=query(ll,mid,n*2);
int y=query(mid+1,rr,n*2+1);
int z=0;
z=min(rr,t[n*2+1].l+t[n*2+1].bl-1)-max(ll,t[n*2].r-t[n*2].br+1)+1;
return max(x,max(y,z));
}
}
int main()
{
int i,j,k,x;
while(scanf("%d",&n)!=EOF)
{
for(i=1;i<=n;i++)
scanf("%d",&a[i]);
build(1,n,1);
scanf("%d",&m);
while(m--)
{
scanf("%d%d%d",&x,&i,&j);
if(x==0)
{
printf("%d\n",query(i,j,1));
}
else
{
if(i>j)
swap(i,j);
update(i,j,1,1);
}
}
}
return 0;
}