题意简述
维护一个序列,支持如下操作
- 把[a, b]区间内的所有数全变成0
- 把[a, b]区间内的所有数全变成1
- 把[a,b]区间内所有的0变成1,所有的1变成0
- 询问[a, b]区间内总共有多少个1
- 询问[a, b]区间内最多有多少个连续的1
题解思路
线段树
对于每个节点,维护对应区间
- sum:1的个数
- L0:连续0的最大长度
- L1:连续1的最长长度
- l0:包含区间左端点的连续0的最大长度
- l1:包含区间左端点的连续1的最大长度
- r0:包含区间右端点的连续0的最大长度
- r1:包含区间右端点的连续1的最大长度
- la:-1表示不变,0表示全为0,1表示全为1
- tn:0表示不变,1表示翻转
可推出关系
ls表示左儿子,rs表示右儿子,lenl表示左儿子区间长度,lenr表示右儿子区间长度
\(L0 = max(L0_{ls}, L0_{rs}, r0_{ls}+l0_{rs})\)
\(l0 = l0_{ls} + l0_{rs} * (l0_{ls} == lenl)\)
\(r0 = r0_{rs} + r0_{ls} * (r0_{rs} == lenr)\)
L1,l1,r1同理可得。
下传标记时先判断la,因为全都变为1个数可以覆盖翻转的结果
代码
#include <cstdio>
#include <algorithm>
#define ci const int
#define ls x<<1
#define rs x<<1|1
#define mid ((l+r)>>1)
using std::max; using std::swap;
int n,m,opt,a,b,t;
struct Node1 { int sum,L0,L1,l0,l1,r0,r1,la,tn; };
struct Node2 { int sum,L1,l1,r1; };
struct Segement_Tree {
Node1 s[400010];
void push_up(ci& x,ci& lenl,ci& lenr) {
s[x].sum=s[ls].sum+s[rs].sum;
s[x].L0=max(max(s[ls].L0,s[rs].L0),s[ls].r0+s[rs].l0);
s[x].L1=max(max(s[ls].L1,s[rs].L1),s[ls].r1+s[rs].l1);
s[x].l0=s[ls].l0+s[rs].l0*(s[ls].l0==lenl);
s[x].l1=s[ls].l1+s[rs].l1*(s[ls].l1==lenl);
s[x].r0=s[rs].r0+s[ls].r0*(s[rs].r0==lenr);
s[x].r1=s[rs].r1+s[ls].r1*(s[rs].r1==lenr);
}
void push_down(ci& x,ci& lenl,ci& lenr) {
if (s[x].la^-1) {
s[ls].sum=s[ls].L1=s[ls].l1=s[ls].r1=s[x].la*lenl;
s[rs].sum=s[rs].L1=s[rs].l1=s[rs].r1=s[x].la*lenr;
s[ls].L0=s[ls].l0=s[ls].r0=(s[x].la^1)*lenl;
s[rs].L0=s[rs].l0=s[rs].r0=(s[x].la^1)*lenr;
s[ls].la=s[rs].la=s[x].la; s[x].la = -1; s[ls].tn=s[rs].tn=0;
}
if (s[x].tn) {
s[ls].sum=lenl-s[ls].sum; s[rs].sum=lenr-s[rs].sum;
swap(s[ls].L0,s[ls].L1); swap(s[rs].L0,s[rs].L1);
swap(s[ls].l0,s[ls].l1); swap(s[rs].l0,s[rs].l1);
swap(s[ls].r0,s[ls].r1); swap(s[rs].r0,s[rs].r1);
s[ls].tn^=1; s[rs].tn^=1; s[x].la=-1; s[x].tn=0;
}
}
void build(ci& x,ci& l,ci& r) {
s[x].la=-1;
if (l==r) { scanf("%d",&t); s[x]=(Node1){t,t^1,t,t^1,t,t^1,t,-1,0}; return; }
build(ls,l,mid); build(rs,mid+1,r);
push_up(x,mid-l+1,r-mid);
}
void change(ci& x,ci& L,ci& R,ci& l,ci& r,ci& k) {
if (L<=l&&r<=R) {
int x1=k*(r-l+1),x2=(k^1)*(r-l+1);
s[x] = (Node1){x1,x2,x1,x2,x1,x2,x1,k,0};
return;
}
if (R<l||r<L) return;
push_down(x,mid-l+1,r-mid);
change(ls,L,R,l,mid,k); change(rs,L,R,mid+1,r,k);
push_up(x,mid-l+1,r-mid);
}
void turn(ci& x,ci& L,ci& R,ci& l,ci& r) {
if (L<=l&&r<=R) {
s[x].sum=r-l+1-s[x].sum; s[x].tn^=1;
swap(s[x].L0,s[x].L1); swap(s[x].l0,s[x].l1); swap(s[x].r0,s[x].r1);
return;
}
if (R<l||r<L) return;
push_down(x,mid-l+1,r-mid);
turn(ls,L,R,l,mid); turn(rs,L,R,mid+1,r);
push_up(x,mid-l+1,r-mid);
}
Node2 query(ci& x,ci& L,ci& R,ci& l,ci& r) {
if (L<=l&&r<=R) return (Node2){s[x].sum,s[x].L1,s[x].l1,s[x].r1};
if (R<l||r<L) return (Node2){0,0,0,0};
push_down(x,mid-l+1,r-mid);
Node2 s,s1=query(ls,L,R,l,mid),s2=query(rs,L,R,mid+1,r);
s.sum=s1.sum+s2.sum;
s.L1=max(max(s1.L1,s2.L1),s1.r1+s2.l1);
s.l1=s1.l1+s2.l1*(s1.l1==mid-l+1);
s.r1=s2.r1+s1.r1*(s2.r1==r-mid);
return s;
}
}sgt;
int main() {
scanf("%d%d",&n,&m);
sgt.build(1,1,n);
for (register int i=1; i<=m; ++i) {
scanf("%d%d%d",&opt,&a,&b); ++a; ++b;
if (opt==4) printf("%d\n",sgt.query(1,a,b,1,n).L1);
else if (opt==3) printf("%d\n",sgt.query(1,a,b,1,n).sum);
else if (opt==2) sgt.turn(1,a,b,1,n);
else sgt.change(1,a,b,1,n,opt);
}
}