// Segment tree template problem (线段树模板题).
#include <cstdio>
#include <iostream>
using namespace std;
// Global input/state:
//   n - number of sequence elements, m - number of operations;
//   f/u/v - opcode and (1-based, after conversion in main) range bounds of the current operation.
int n,m,f,u,v;
// Per-node segment tree data. Node M has children 2M and 2M+1; the root is node 1.
//   l0/r0/m0 - longest run of 0s starting at the left end / ending at the right end / anywhere in the node's range.
//   l1/r1/m1 - the same three quantities for runs of 1s.
// NOTE(review): sized for ~4 * 100000 nodes, i.e. assumes n is at most about 100000 — confirm against the problem limits.
int l0[400010],r0[400010],m0[400010],l1[400010],r1[400010],m1[400010];
//   l/r - inclusive bounds of the node's range; sum - number of 1s in the range;
//   lz  - pending lazy tag: 0 none, 1 assign 0, 2 assign 1, 3 flip.
int l[400010],r[400010],lz[400010],sum[400010];
// Reads the next non-negative decimal integer from stdin.
// Any non-digit characters before the number are skipped (so a '-' sign is ignored,
// i.e. negative values are not supported). Returns 0 if end of input is reached
// before any digit is found.
int read_int () {
    // getchar() returns int so that EOF is distinguishable from a real byte.
    // Storing it in a char made the skip loop spin forever at end of input:
    // EOF stays a "non-digit" with signed char, and becomes 255 with unsigned char.
    int c = getchar();
    while(c != EOF && (c < '0' || c > '9'))
        c = getchar();
    int re = 0;
    for(; c >= '0' && c <= '9'; c = getchar())
        re = re * 10 + (c - '0');
    return re;
}
void push_up (int M) {
int ls = M << 1;
int rs = ls | 1;
sum[M] = sum[ls] + sum[rs];
l0[M] = l0[ls];
if(l0[M] == r[ls] - l[ls] + 1)
l0[M] += l0[rs];
r0[M] = r0[rs];
if(r0[M] == r[rs] - l[rs] + 1)
r0[M] += r0[ls];
m0[M] = max(m0[ls],m0[rs]);
m0[M] = max(m0[M],l0[rs] + r0[ls]);
l1[M] = l1[ls];
if(l1[M] == r[ls] - l[ls] + 1)
l1[M] += l1[rs];
r1[M] = r1[rs];
if(r1[M] == r[rs] - l[rs] + 1)
r1[M] += r1[ls];
m1[M] = max(m1[ls],m1[rs]);
m1[M] = max(m1[M],l1[rs] + r1[ls]);
}
// Builds the subtree rooted at node M over positions [L, R].
// Leaf values are read from stdin in left-to-right order. The raw value read is stored
// in sum; a zero marks the leaf as a 0, anything else as a 1
// (presumably inputs are only 0/1 — confirm).
void build (int L,int R,int M) {
    l[M] = L;
    r[M] = R;
    if(L != R) {
        const int mid = (L + R) / 2;
        build(L, mid, 2 * M);
        build(mid + 1, R, 2 * M + 1);
        push_up(M);
        return;
    }
    // Leaf node.
    sum[M] = read_int();
    if(sum[M] == 0)
        l0[M] = r0[M] = m0[M] = 1;
    else
        l1[M] = r1[M] = m1[M] = 1;
}
// Sets every position in node M's range to 0 and records lazy tag 1 (assign 0).
void modi_0 (int M) {
    const int len = r[M] - l[M] + 1;
    sum[M] = 0;
    l1[M] = 0;
    r1[M] = 0;
    m1[M] = 0;
    l0[M] = len;
    r0[M] = len;
    m0[M] = len;
    lz[M] = 1;
}
// Sets every position in node M's range to 1 and records lazy tag 2 (assign 1).
void modi_1 (int M) {
    const int len = r[M] - l[M] + 1;
    l0[M] = 0;
    r0[M] = 0;
    m0[M] = 0;
    sum[M] = len;
    l1[M] = len;
    r1[M] = len;
    m1[M] = len;
    lz[M] = 2;
}
// Exchanges the values referred to by a and b (safe when both refer to the same int).
void swp (int &a,int &b) {
    const int held = b;
    b = a;
    a = held;
}
// Flips every bit in node M's range: exchanges the 0-run and 1-run statistics,
// replaces the count of 1s with the count of 0s, and composes a flip onto the lazy tag.
void modi_2 (int M) {
    swp(l0[M], l1[M]);
    swp(r0[M], r1[M]);
    swp(m0[M], m1[M]);
    sum[M] = (r[M] - l[M] + 1) - sum[M];
    // Tag after composing a flip: assign-0 <-> assign-1, flip <-> none.
    switch(lz[M]) {
        case 1:
            lz[M] = 2;
            break;
        case 2:
            lz[M] = 1;
            break;
        case 3:
            lz[M] = 0;
            break;
        default:
            lz[M] = 3;
            break;
    }
}
// Applies node M's pending lazy tag (if any) to both children, then clears it.
// Tag 1 assigns 0, tag 2 assigns 1, any other nonzero tag flips.
void push_down (int M) {
    const int tag = lz[M];
    if(tag == 0)
        return;
    const int lc = M << 1;
    const int rc = lc | 1;
    switch(tag) {
        case 1:
            modi_0(lc);
            modi_0(rc);
            break;
        case 2:
            modi_1(lc);
            modi_1(rc);
            break;
        default:
            modi_2(lc);
            modi_2(rc);
            break;
    }
    lz[M] = 0;
}
// Applies operation f to positions [L, R], which must lie inside node M's range.
//   f == 0: assign 0; f == 1: assign 1; any other value: flip.
void change (int f,int L,int R,int M) {
    if(L == l[M] && R == r[M]) {
        switch(f) {
            case 0:
                modi_0(M);
                break;
            case 1:
                modi_1(M);
                break;
            default:
                modi_2(M);
                break;
        }
        return;
    }
    push_down(M);
    const int mid = (l[M] + r[M]) / 2;
    const int lc = M << 1;
    const int rc = lc | 1;
    if(L <= mid && R > mid) {
        // Range straddles the midpoint: update both halves.
        change(f, L, mid, lc);
        change(f, mid + 1, R, rc);
    } else {
        change(f, L, R, L > mid ? rc : lc);
    }
    push_up(M);
}
// Returns the number of 1s in positions [L, R], which must lie inside node M's range.
int find_sum (int L,int R,int M) {
    if(l[M] == L && r[M] == R)
        return sum[M];
    push_down(M);
    const int mid = (l[M] + r[M]) / 2;
    const int lc = M << 1;
    const int rc = lc | 1;
    if(L <= mid && R > mid)
        return find_sum(L, mid, lc) + find_sum(mid + 1, R, rc);
    return (R <= mid) ? find_sum(L, R, lc) : find_sum(L, R, rc);
}
// Returns the length of the longest run of consecutive 1s within positions [L, R]
// (which must lie inside node M's range). Also writes through ll1 / rr1 the length
// of the run of 1s that starts at L / ends at R.
int find_1 (int L,int R,int &ll1,int &rr1,int M) {
    if(l[M] == L && r[M] == R) {
        ll1 = l1[M];
        rr1 = r1[M];
        return m1[M];
    }
    push_down(M);
    const int mid = (l[M] + r[M]) / 2;
    const int lc = M << 1;
    const int rc = lc | 1;
    if(L <= mid && R > mid) {
        int leftPre, leftSuf, rightPre, rightSuf;
        const int leftBest = find_1(L, mid, leftPre, leftSuf, lc);
        const int rightBest = find_1(mid + 1, R, rightPre, rightSuf, rc);
        // A prefix (suffix) run extends across the midpoint only if it fills its whole half.
        ll1 = (leftPre == mid - L + 1) ? leftPre + rightPre : leftPre;
        rr1 = (rightSuf == R - mid) ? rightSuf + leftSuf : rightSuf;
        return max(max(leftBest, rightBest), leftSuf + rightPre);
    }
    return (L > mid) ? find_1(L, R, ll1, rr1, rc) : find_1(L, R, ll1, rr1, lc);
}
// Reads n, m and the initial n values, then processes m operations "f a b", where
// a and b are converted from 0-based input to the tree's 1-based positions:
//   0 assign 0, 1 assign 1, 2 flip, 3 print count of 1s, 4 (or more) print longest run of 1s.
int main () {
    n = read_int();
    m = read_int();
    build(1, n, 1);
    int pre, suf;   // out-parameters required by find_1; not used here
    for(int op = 0; op < m; ++op) {
        f = read_int();
        u = read_int() + 1;
        v = read_int() + 1;
        if(f == 3)
            printf("%d\n", find_sum(u, v, 1));
        else if(f > 3)
            printf("%d\n", find_1(u, v, pre, suf, 1));
        else
            change(f, u, v, 1);
    }
    return 0;
}