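// 一眼线段树代码题,交上去一直WA,找半天bug原来是区间合并写错了!!
// (Looked like a routine segment-tree problem, but it kept getting Wrong Answer;
//  after a long bug hunt the culprit turned out to be the interval-merge logic.)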
#pragma GCC optimize(2)
#include<bits/stdc++.h>
using namespace std;
// Maximum sequence length supported (1e5 plus a small safety margin).
const int maxn = 100000 + 5;

// One segment-tree node over the 1-based position range [l, r] of a 0/1 sequence.
// In every paired array, index 0 describes runs of 0s and index 1 runs of 1s.
struct node {
    int l, r;      // covered position range, inclusive
    bool set[2];   // pending "assign the whole range to 0 / to 1" tag
    bool rflag;    // pending "flip every bit" tag (pushed after a pending assignment)
    int sum;       // number of 1s in [l, r]
    int ls[2];     // longest run of 0s / 1s that starts at l
    int rs[2];     // longest run of 0s / 1s that ends at r
    int ms[2];     // longest run of 0s / 1s anywhere in [l, r]

    // Split point dividing [l, r] between the left and right child.
    int mid() const { return (l + r) / 2; }
    // Number of positions covered by this node.
    int cover() const { return r - l + 1; }
} tr[maxn << 2];   // 4 * maxn slots for a recursive segment tree over up to maxn leaves
void pushup(int rt) {
tr[rt].sum = tr[rt<<1].sum + tr[rt<<1|1].sum;
for(int i = 0; i < 2; i++) {
tr[rt].ls[i] = tr[rt<<1].ls[i];
if(tr[rt<<1].ls[i] == tr[rt<<1].cover()) tr[rt].ls[i] += tr[rt<<1|1].ls[i];
tr[rt].rs[i] = tr[rt<<1|1].rs[i];
if(tr[rt<<1|1].rs[i] == tr[rt<<1|1].cover()) tr[rt].rs[i] += tr[rt<<1].rs[i];
tr[rt].ms[i] = max(max(tr[rt<<1].ms[i], tr[rt<<1|1].ms[i]), tr[rt<<1].rs[i] + tr[rt<<1|1].ls[i]);
}
}
void build(int rt, int l, int r) {
tr[rt].l = l;tr[rt].r = r;tr[rt].set[0] = tr[rt].set[1] = tr[rt].rflag = 0;
if(l == r) {
int c;scanf("%d", &c);
tr[rt].ls[c] = tr[rt].rs[c] = tr[rt].ms[c] = 1;
tr[rt].ls[c^1] = tr[rt].rs[c^1] = tr[rt].ms[c^1] = 0;
tr[rt].sum = tr[rt].ms[1];
return;
}
int mid = tr[rt].mid();
build(rt<<1, l, mid);
build(rt<<1|1, mid+1, r);
pushup(rt);
}
void pushdown(int rt) {
if(tr[rt].set[0] || tr[rt].set[1]) {
int c = tr[rt].set[0]?0:1;
tr[rt<<1].set[c] = 1;
tr[rt<<1].set[c^1] = tr[rt<<1].rflag = 0;
tr[rt<<1].ls[c] = tr[rt<<1].rs[c] = tr[rt<<1].ms[c] = tr[rt<<1].cover();
tr[rt<<1].ls[c^1] = tr[rt<<1].rs[c^1] = tr[rt<<1].ms[c^1] = 0;
tr[rt<<1].sum = tr[rt<<1].ms[1];
tr[rt<<1|1].set[c] = 1;
tr[rt<<1|1].set[c^1] = tr[rt<<1|1].rflag = 0;
tr[rt<<1|1].ls[c] = tr[rt<<1|1].rs[c] = tr[rt<<1|1].ms[c] = tr[rt<<1|1].cover();
tr[rt<<1|1].ls[c^1] = tr[rt<<1|1].rs[c^1] = tr[rt<<1|1].ms[c^1] = 0;
tr[rt<<1|1].sum = tr[rt<<1|1].ms[1];
tr[rt].set[0] = tr[rt].set[1] = 0;
}
if(tr[rt].rflag) {
swap(tr[rt<<1].ms[1], tr[rt<<1].ms[0]);
swap(tr[rt<<1].ls[1], tr[rt<<1].ls[0]);
swap(tr[rt<<1].rs[1], tr[rt<<1].rs[0]);
tr[rt<<1].sum = tr[rt<<1].cover()-tr[rt<<1].sum;
tr[rt<<1].rflag ^= 1;
swap(tr[rt<<1|1].ms[1], tr[rt<<1|1].ms[0]);
swap(tr[rt<<1|1].ls[1], tr[rt<<1|1].ls[0]);
swap(tr[rt<<1|1].rs[1], tr[rt<<1|1].rs[0]);
tr[rt<<1|1].sum = tr[rt<<1|1].cover()-tr[rt<<1|1].sum;
tr[rt<<1|1].rflag ^= 1;
tr[rt].rflag = 0;
}
}
void Set(int rt, int l, int r, int c) {
if(tr[rt].l == l && tr[rt].r == r) {
tr[rt].set[c] = 1;
tr[rt].set[c^1] = tr[rt].rflag = 0;
tr[rt].ls[c] = tr[rt].rs[c] = tr[rt].ms[c] = tr[rt].cover();
tr[rt].ls[c^1] = tr[rt].rs[c^1] = tr[rt].ms[c^1] = 0;
tr[rt].sum = tr[rt].ms[1];
return;
}
pushdown(rt);
int mid = tr[rt].mid();
if(r <= mid) Set(rt<<1, l, r, c);
else if(l > mid) Set(rt<<1|1, l, r, c);
else {
Set(rt<<1, l, mid, c);
Set(rt<<1|1, mid+1, r, c);
}
pushup(rt);
}
void Reverse(int rt, int l, int r) {
if(tr[rt].l == l && tr[rt].r == r) {
swap(tr[rt].ms[1], tr[rt].ms[0]);
swap(tr[rt].ls[1], tr[rt].ls[0]);
swap(tr[rt].rs[1], tr[rt].rs[0]);
tr[rt].sum = tr[rt].cover()-tr[rt].sum;
tr[rt].rflag ^= 1;
return;
}
pushdown(rt);
int mid = tr[rt].mid();
if(r <= mid) Reverse(rt<<1, l, r);
else if(l > mid) Reverse(rt<<1|1, l, r);
else {
Reverse(rt<<1, l, mid);
Reverse(rt<<1|1, mid+1, r);
}
pushup(rt);
}
// Count the 1s in [l, r], which must lie inside node rt's range.
int Querysum(int rt, int l, int r) {
    if (tr[rt].l != l || tr[rt].r != r) {
        pushdown(rt);
        const int m = tr[rt].mid();
        if (r <= m) return Querysum(rt << 1, l, r);
        if (l > m) return Querysum(rt << 1 | 1, l, r);
        const int leftOnes = Querysum(rt << 1, l, m);
        return leftOnes + Querysum(rt << 1 | 1, m + 1, r);
    }
    return tr[rt].sum;
}
// Return a node summarising positions [l, r], which must lie inside node rt's
// range. The result's l/r equal the query range, so cover() is valid on it.
// Its run fields (ls/rs/ms for both colours) and sum are all filled in.
node Queryconti(int rt, int l, int r) {
    if (tr[rt].l == l && tr[rt].r == r) return tr[rt];
    pushdown(rt);
    const int m = tr[rt].mid();
    if (r <= m) return Queryconti(rt << 1, l, r);
    if (l > m) return Queryconti(rt << 1 | 1, l, r);
    // Named lhs/rhs (not tl/tr) so the global tree array `tr` is never shadowed.
    node lhs = Queryconti(rt << 1, l, m);
    node rhs = Queryconti(rt << 1 | 1, m + 1, r);
    // Value-initialise so no field of the returned copy is indeterminate;
    // the tags stay false because a query summary carries no pending work.
    node res = node();
    res.l = l;
    res.r = r;
    res.sum = lhs.sum + rhs.sum;
    for (int i = 0; i < 2; i++) {
        // A prefix run extends across the split only if it fills the left part.
        res.ls[i] = lhs.ls[i];
        if (lhs.ls[i] == lhs.cover()) res.ls[i] += rhs.ls[i];
        // Likewise a suffix run extends across only if it fills the right part.
        res.rs[i] = rhs.rs[i];
        if (rhs.rs[i] == rhs.cover()) res.rs[i] += lhs.rs[i];
        res.ms[i] = std::max(std::max(lhs.ms[i], rhs.ms[i]), lhs.rs[i] + rhs.ls[i]);
    }
    return res;
}
// Read n, m, the initial 0/1 sequence, then m operations "op a b" whose
// positions are 0-based and inclusive, and answer them:
//   0: set [a, b] to 0      1: set [a, b] to 1      2: flip every bit in [a, b]
//   3: print the number of 1s in [a, b]
//   4: print the length of the longest run of consecutive 1s in [a, b]
int main() {
    int n, m;
    // build(1, 1, n) never terminates for n < 1, so reject that input up front.
    if (scanf("%d%d", &n, &m) != 2 || n < 1) return 0;
    build(1, 1, n);
    for (int i = 1; i <= m; i++) {
        int op, x, y;
        // Stop on truncated input instead of acting on uninitialised values.
        if (scanf("%d%d%d", &op, &x, &y) != 3) break;
        // The tree routines require 0 <= x <= y < n; anything else would walk
        // past the leaves, so malformed operations are skipped.
        if (x < 0 || y >= n || x > y) continue;
        // The tree is 1-based while input positions are 0-based.
        x++;
        y++;
        switch (op) {
        case 0:
        case 1:
            Set(1, x, y, op);
            break;
        case 2:
            Reverse(1, x, y);
            break;
        case 3:
            printf("%d\n", Querysum(1, x, y));
            break;
        case 4:
            printf("%d\n", Queryconti(1, x, y).ms[1]);
            break;
        default:
            break;
        }
    }
    return 0;
}