0~3操作都是很裸的线段树操作,主要是这个4操作有点复杂。
定义一下数组的含义:
l0/1:从左端点开始连续的0/1的个数
r0/1:从右端点开始连续的0/1的个数
m0/1:连续的最长的0/1串
sum:1的个数
在update操作的时候,l和r数组的更新都可以直接由孩子得到,但要注意一类情况:比如左孩子全部是1,那么就可以把右孩子左端的1连起来。m数组由两种情况得到,一种是孩子的m数组,还有一种是左孩子的右端和右孩子的左端拼在一起得到。
0~2操作都只要打上lazy-tag标记即可,注意的是如果打上了tag0/1标记,其他两个标记要被清除掉。
3操作只要查询sum即可。
对于4操作,我一开始想的就是直接统计,后来发现很难处理区间合并的情况。后来网上找了一种写法,就是把有贡献的区间一层一层合并存在一个新的节点里面,这样只要update一下就可以得到m数组了,非常神奇的写法。
#include<cmath>
#include<cstdio>
#include<vector>
#include<queue>
#include<cstring>
#include<iomanip>
#include<stdlib.h>
#include<iostream>
#include<algorithm>
#define inf 1000000000
#define mod 1000000007
#define N 400005
#define fo(i,a,b) for(i=a;i<=b;i++)
#define fd(i,a,b) for(i=a;i>=b;i--)
using namespace std;
int lson[N],rson[N],l0[N],l1[N],r0[N],r1[N],len[N],m0[N],m1[N],sum[N];
int ll[N],rr[N],tag0[N],tag1[N],tagrev[N],a[N];
int tot,total,n,m,i,root,opt,l,r;
void update(int x)
{
int ls = lson[x] , rs = rson[x];
l0[x] = l0[ls]; l1[x] = l1[ls];
r0[x] = r0[rs]; r1[x] = r1[rs];
if (l0[ls] == len[ls]) l0[x] = l0[ls] + l0[rs];
if (l1[ls] == len[ls]) l1[x] = l1[ls] + l1[rs];
if (r0[rs] == len[rs]) r0[x] = r0[ls] + r0[rs];
if (r1[rs] == len[rs]) r1[x] = r1[ls] + r1[rs];
m0[x] = max(r0[ls]+l0[rs],max(m0[ls],m0[rs]));
m1[x] = max(r1[ls]+l1[rs],max(m1[ls],m1[rs]));
sum[x] = sum[ls] + sum[rs];
}
void build(int &rt,int l,int r)
{
rt = ++tot;
ll[rt] = l; rr[rt] = r; len[rt] = r - l + 1;
tag0[rt] = tag1[rt] = tagrev[rt] = 0;
if (l == r)
{
lson[rt] = rson[rt] = 0;
l1[rt] = r1[rt] = m1[rt] = sum[rt] = (a[l] == 1);
l0[rt] = r0[rt] = m0[rt] = (a[l] == 0);
return;
}
int mid = (l + r) >> 1;
build(lson[rt],l,mid); build(rson[rt],mid+1,r);
update(rt);
}
void mark0(int x)
{
tag0[x] = 1; tag1[x] = tagrev[x] = 0;
l0[x] = r0[x] = m0[x] = len[x];
l1[x] = r1[x] = m1[x] = sum[x] = 0;
}
void mark1(int x)
{
tag1[x] = 1; tag0[x] = tagrev[x] = 0;
l0[x] = r0[x] = m0[x] = 0;
l1[x] = r1[x] = m1[x] = sum[x] = len[x];
}
void markrev(int x)
{
tagrev[x] ^= 1;
swap(l0[x],l1[x]); swap(r0[x],r1[x]);
swap(m0[x],m1[x]); sum[x] = len[x] - sum[x];
}
void pushdown(int x)
{
if (tag0[x]) {mark0(lson[x]); mark0(rson[x]); tag0[x] = 0;}
if (tag1[x]) {mark1(lson[x]); mark1(rson[x]); tag1[x] = 0;}
if (tagrev[x]) {markrev(lson[x]); markrev(rson[x]); tagrev[x] = 0;}
}
void change(int rt,int l,int r,int opt)
{
if (ll[rt] > r || rr[rt] < l) return;
pushdown(rt);
if (l <= ll[rt] && rr[rt] <= r)
{
if (opt == 0) mark0(rt); if (opt == 1) mark1(rt);
if (opt == 2) markrev(rt); return;
}
change(lson[rt],l,r,opt); change(rson[rt],l,r,opt);
update(rt);
}
int query1(int rt,int l,int r)
{
if (ll[rt] > r || rr[rt] < l) return 0;
pushdown(rt);
if (l <= ll[rt] && rr[rt] <= r) return sum[rt];
return query1(lson[rt],l,r) + query1(rson[rt],l,r);
}
int query2(int rt,int l,int r)
{
pushdown(rt);
if (l == ll[rt] && r == rr[rt]) return rt;
int mid = (ll[rt] + rr[rt]) >> 1;
if (r <= mid) return query2(lson[rt],l,r);
if (l > mid) return query2(rson[rt],l,r);
int ans = ++total;
lson[ans] = query2(lson[rt],l,mid); rson[ans] = query2(rson[rt],mid+1,r);
update(ans);
return ans;
}
int main()
{
scanf("%d%d",&n,&m);
fo(i,1,n) scanf("%d",&a[i]);
build(root,1,n);
fo(i,1,m)
{
scanf("%d%d%d",&opt,&l,&r); l++; r++;
if (opt <= 2) {change(root,l,r,opt); continue;}
if (opt == 3) {printf("%d\n",query1(root,l,r)); continue;}
if (opt == 4) {total = tot; printf("%d\n",m1[query2(root,l,r)]); continue;}
}
return 0;
}