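// 题意:给定一个长度为 n 的 0/1 序列(下标从 0 开始)和 m 个操作 "op a b"(作用于闭区间 [a, b]):op=0 区间清零;op=1 区间置 1;op=2 区间内 0/1 互换;op=3 查询区间内 1 的个数;op=4 查询区间内最长连续 1 的长度。
// 思路:线段树。每个结点维护 1 的个数 sum,以及 0 和 1 各自的前缀连续长度 pre、后缀连续长度 suf 和区间内最长连续长度 mx。之所以连 0 的信息也一起维护,是因为区间取反时只需交换 0/1 两组数据即可。
// 调试了很久才发现 pushdown 里写错了一句,心痛——本来一小时就能写完(敲代码还是太慢)。我觉得这份代码比网上的题解写得优雅多了。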
#include<iostream>
#include<string>
#include<string.h>
#include<stdlib.h>
#include<cstdio>
#include<stdio.h>
#include<set>
#include<map>
#include<deque>
#include<stack>
#include<vector>
#include<queue>
#include<algorithm>
#include<cstring>
#include<cmath>
#include<sstream>
#include<istream>
#include<ostream>
#include<sstream>
using namespace std;
// Child-index helpers for the implicit binary tree stored in t[]: node n's
// children are 2n and 2n+1. They expand in terms of a variable that must be
// named `n` at the use site. The parentheses make expressions such as
// `lc + 1` or `2 * rc` expand as intended.
#define lc (n << 1)
#define rc (n << 1 | 1)
const int maxn = 100010;
// One segment-tree node covering positions [l, r] (inclusive).
// In the two-element arrays, index 0 describes runs of 0s and index 1 runs of 1s.
// Both are kept so that a range flip is just a swap of the two sets.
struct segtree {
    int l, r;      // covered interval, inclusive
    int pre[2];    // length of the run of 0s / 1s that starts at l
    int suf[2];    // length of the run of 0s / 1s that ends at r
    int sum;       // number of 1s in [l, r]
    int mx[2];     // longest run of 0s / 1s inside [l, r]
    bool txor;     // pending flip tag, not yet pushed to the children
    int lazy;      // pending assignment tag: 1 = set to 1, -1 = set to 0, 0 = none
}t[maxn << 2];
// Number of positions covered by segment-tree node n (its interval is inclusive).
inline int getlen(int n)
{
    const int lo = t[n].l;
    const int hi = t[n].r;
    return (hi + 1) - lo;
}
// Recompute node n's aggregates from its two children.
// For each bit value b (0 and 1):
// - the prefix run continues into the right child only when the left child
//   consists entirely of b;
// - the suffix run is the mirror image of the prefix run;
// - the longest run lies inside one child or straddles the midpoint.
inline void pushup(int n)
{
    const int lenL = getlen(lc);
    const int lenR = getlen(rc);
    t[n].sum = t[lc].sum + t[rc].sum;
    for (int b = 0; b < 2; ++b)
    {
        const int lp = t[lc].pre[b];
        const int ls = t[lc].suf[b];
        const int lm = t[lc].mx[b];
        const int rp = t[rc].pre[b];
        const int rs = t[rc].suf[b];
        const int rm = t[rc].mx[b];
        t[n].pre[b] = (lp == lenL) ? lenL + rp : lp;
        t[n].suf[b] = (rs == lenR) ? lenR + ls : rs;
        t[n].mx[b] = max(max(lm, rm), ls + rp);
    }
}
// Apply an operation to the whole of node n and record it as a pending tag.
//   cmd 0: every position becomes 0 (lazy = -1)
//   cmd 1: every position becomes 1 (lazy = 1)
//   cmd 2: every bit is flipped
// Any other cmd leaves the node untouched.
// A flip on a node that already carries an assignment tag is folded into that
// tag (assign-0 <-> assign-1). Only a tag-free node gets its txor toggled.
inline void setstatus(int n, int cmd)
{
    const int len = getlen(n);
    switch (cmd)
    {
    case 0:
    case 1:
    {
        const int ones = cmd;       // run length factor for 1s
        const int zeros = 1 - cmd;  // run length factor for 0s
        t[n].sum = ones * len;
        t[n].pre[1] = t[n].suf[1] = t[n].mx[1] = ones * len;
        t[n].pre[0] = t[n].suf[0] = t[n].mx[0] = zeros * len;
        t[n].lazy = (cmd == 1) ? 1 : -1;
        t[n].txor = false;
        break;
    }
    case 2:
        t[n].sum = len - t[n].sum;
        swap(t[n].pre[0], t[n].pre[1]);
        swap(t[n].suf[0], t[n].suf[1]);
        swap(t[n].mx[0], t[n].mx[1]);
        if (t[n].lazy == 0)
            t[n].txor = !t[n].txor;
        else
            t[n].lazy = -t[n].lazy;
        break;
    default:
        break;
    }
}
// Propagate node n's pending tag (if any) to both children, then clear it.
// An assignment tag (lazy = 1 or -1) is replayed on each child as
// setstatus(child, 1) or setstatus(child, 0). A flip tag is replayed as
// setstatus(child, 2).
inline void pushdown(int n)
{
    if (t[n].lazy != 0)
    {
        const int cmd = (t[n].lazy == 1) ? 1 : 0;
        setstatus(lc, cmd);
        setstatus(rc, cmd);
        t[n].lazy = 0;
        return;
    }
    if (!t[n].txor)
        return;
    setstatus(lc, 2);
    setstatus(rc, 2);
    t[n].txor = false;
}
// Build the subtree rooted at node n over positions [l, r].
// Each leaf reads one initial value from stdin, in left-to-right order.
// Tags are cleared here so that build() does not depend on t[] having been
// zeroed beforehand (e.g. when the tree is rebuilt for another input).
void build(int n, int l, int r)
{
    t[n].l = l; t[n].r = r;
    t[n].lazy = 0;        // no pending assignment
    t[n].txor = false;    // no pending flip
    if (l == r)
    {
        int v = 0;
        if (scanf("%d", &v) != 1)
            v = 0;                  // missing input: treat as 0 instead of leaving stale data
        const int one = (v != 0);   // clamp to one bit so the leaf's run lengths stay consistent
        t[n].sum = one;
        t[n].pre[1] = t[n].suf[1] = t[n].mx[1] = one;
        t[n].pre[0] = t[n].suf[0] = t[n].mx[0] = 1 - one;
        return;
    }
    int mid = l + ((r - l) >> 1);
    build(lc, l, mid);
    build(rc, mid + 1, r);
    pushup(n);
}
void update(int n, int l, int r, int cmd)
{
if (l > t[n].r || r < t[n].l)return;
if (l <= t[n].l && r >= t[n].r)
{
setstatus(n, cmd);
return;
}
pushdown(n);
update(lc, l, r, cmd);
update(rc, l, r, cmd);
pushup(n);
}
int qsum(int n, int l, int r)
{
if (l > t[n].r || r < t[n].l)return 0;
if (l <= t[n].l && r >= t[n].r)return t[n].sum;
pushdown(n);
return qsum(lc, l, r) + qsum(rc, l, r);
}
int qpre(int n, int l, int r)
{
if (l > t[n].r || r < t[n].l)return 0;
if (l <= t[n].l && r >= t[n].r)return t[n].pre[1];
pushdown(n);
if (r <= t[lc].r)return qpre(lc, l, r);
if (l >= t[rc].l)return qpre(rc, l, r);
int ret = qpre(lc, l, r);
if (ret == getlen(lc))
return ret + qpre(rc, l, r);
return ret;
}
int qsuf(int n, int l, int r)
{
if (l > t[n].r || r < t[n].l)return 0;
if (l <= t[n].l && r >= t[n].r)return t[n].suf[1];
pushdown(n);
if (l >= t[rc].l)return qsuf(rc, l, r);
if (r <= t[lc].r)return qsuf(lc, l, r);
int ret = qsuf(rc, l, r);
if (ret == getlen(rc))
return ret + qsuf(lc, l, r);
return ret;
}
int qmax(int n, int l, int r)
{
if (l > t[n].r || r < t[n].l)return 0;
if (l <= t[n].l && r >= t[n].r)return t[n].mx[1];
pushdown(n);
if (r <= t[lc].r)return qmax(lc, l, r);
if (l >= t[rc].l)return qmax(rc, l, r);
int ret = max(qmax(lc, l, r), qmax(rc, l, r));
return max(ret, qsuf(lc, l, r) + qpre(rc, l, r));
}
int main()
{
int n, m, c, l, r;
scanf("%d%d", &n, &m);
memset(t, 0, sizeof t);
build(1, 1, n);
while (m--)
{
scanf("%d%d%d", &c, &l, &r);
if (c < 3)update(1, l + 1, r + 1, c);
else if (c == 3)printf("%d\n", qsum(1, l + 1, r + 1));
else if (c == 4)printf("%d\n", qmax(1, l + 1, r + 1));
}
return 0;
}