原题链接: http://acm.hdu.edu.cn/showproblem.php?pid=3397
一:分析
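题意:给定一个长度为 n 的 0/1 序列(下标从 0 开始),支持五种操作:`0 a b` 把 [a, b] 全部置为 0;`1 a b` 把 [a, b] 全部置为 1;`2 a b` 把 [a, b] 中每一位取反;`3 a b` 输出 [a, b] 中 1 的个数;`4 a b` 输出 [a, b] 中最长连续 1 的长度。
思路:线段树。每个节点同时维护 0 和 1 的前缀最长连续长度、后缀最长连续长度、区间内最长连续长度,以及 1 的个数,这样取反只需交换 0 和 1 的对应字段。懒标记有两种:覆盖和取反。下推时必须保持两种标记的先后顺序,否则"先覆盖再取反"会被误当成只有覆盖,取反就丢失了。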
二:AC代码
#define _CRT_SECURE_NO_DEPRECATE
#include<algorithm>
#include<cmath>
#include<cstdio>
#include<iostream>
using namespace std;
// Segment-tree node over a 0/1 sequence. Each node keeps run summaries for
// BOTH digits, so flipping a range is just a swap of the 0- and 1-fields.
//
// Lazy-tag invariant (maintained by applyCover / applyFlip): a node never holds
// a pending COVER and a pending XOR at the same time.
//   - Flipping a node whose cover is still pending flips the cover value instead.
//   - Covering a node discards any pending flip, because the assignment overrides it.
// This keeps the order of pending operations unambiguous when pushed down.
struct Node
{
    int left, right, mid;   // covered index range [left, right]; mid = (left + right) / 2
    int l0, l1;             // longest run of 0s / 1s starting at `left`
    int r0, r1;             // longest run of 0s / 1s ending at `right`
    int m0, m1;             // longest run of 0s / 1s anywhere in the range
    int sum1;               // number of 1s in the range
    int len;                // right - left + 1
    int XOR;                // pending flip for the children (0 or 1)
    int COVER;              // pending assignment for the children (-1 = none, else 0 or 1)
} p[100005 * 4];
int a[100005];              // input sequence, 0-based; read before build()

// Set every element of node n's range to `value` (0 or 1) and record the
// assignment as the node's pending tag, dropping any pending flip.
static void applyCover(int n, int value)
{
    Node& t = p[n];
    t.COVER = value;
    t.XOR = 0;
    const int zeros = (value == 0 ? t.len : 0);
    const int ones = t.len - zeros;
    t.l0 = t.r0 = t.m0 = zeros;
    t.l1 = t.r1 = t.m1 = t.sum1 = ones;
}

// Invert every element of node n's range and record the flip as the node's
// pending tag. If an assignment is already pending, flipping the assigned
// value is equivalent and keeps at most one tag on the node.
static void applyFlip(int n)
{
    Node& t = p[n];
    if (t.COVER != -1)
        t.COVER ^= 1;
    else
        t.XOR ^= 1;
    std::swap(t.l0, t.l1);
    std::swap(t.r0, t.r1);
    std::swap(t.m0, t.m1);
    t.sum1 = t.len - t.sum1;
}

// Recompute node n's summaries from its two children.
void pushUp(int n)
{
    const Node& L = p[n * 2];
    const Node& R = p[n * 2 + 1];
    Node& t = p[n];
    // A prefix run may extend into the right child only if the left child is
    // entirely that digit; symmetrically for suffix runs.
    t.l0 = L.l0 + (L.l0 == L.len ? R.l0 : 0);
    t.l1 = L.l1 + (L.l1 == L.len ? R.l1 : 0);
    t.r0 = R.r0 + (R.r0 == R.len ? L.r0 : 0);
    t.r1 = R.r1 + (R.r1 == R.len ? L.r1 : 0);
    // The best run is inside one child or crosses the split point.
    t.m0 = std::max(std::max(L.m0, R.m0), L.r0 + R.l0);
    t.m1 = std::max(std::max(L.m1, R.m1), L.r1 + R.l1);
    t.sum1 = L.sum1 + R.sum1;
}

// Push node n's pending tag (at most one, by the invariant above) to its children.
void pushDown(int n)
{
    Node& t = p[n];
    if (t.COVER != -1)
    {
        applyCover(n * 2, t.COVER);
        applyCover(n * 2 + 1, t.COVER);
        t.COVER = -1;
    }
    if (t.XOR)
    {
        applyFlip(n * 2);
        applyFlip(n * 2 + 1);
        t.XOR = 0;
    }
}

// Build the subtree rooted at node n over a[l..r]; clears all lazy tags.
// Assumes l <= r (recursion does not terminate for an empty range).
void build(int n, int l, int r)
{
    Node& t = p[n];
    t.left = l;
    t.right = r;
    t.mid = (l + r) / 2;
    t.len = r - l + 1;
    t.XOR = 0;
    t.COVER = -1;
    if (l == r)
    {
        t.l0 = t.r0 = t.m0 = (a[l] == 0);
        t.l1 = t.r1 = t.m1 = t.sum1 = a[l];
        return;
    }
    build(n * 2, l, t.mid);
    build(n * 2 + 1, t.mid + 1, r);
    pushUp(n);
}

// Apply operation `op` to [l, r], which must lie inside node n's range.
// op 0 / 1: assign that digit to the range; otherwise (op 2): flip the range.
void update(int n, int l, int r, int op)
{
    if (p[n].left == l && p[n].right == r)
    {
        if (op < 2)
            applyCover(n, op);
        else
            applyFlip(n);
        return;
    }
    pushDown(n);
    const int mid = p[n].mid;
    if (r <= mid)
    {
        update(n * 2, l, r, op);
    }
    else if (l > mid)
    {
        update(n * 2 + 1, l, r, op);
    }
    else
    {
        update(n * 2, l, mid, op);
        update(n * 2 + 1, mid + 1, r, op);
    }
    pushUp(n);
}

/* 3 a b: output the number of '1's in [a, b]. */
int query3(int n, int l, int r)
{
    if (p[n].left == l && p[n].right == r)
        return p[n].sum1;
    pushDown(n);
    const int mid = p[n].mid;
    if (r <= mid)
        return query3(n * 2, l, r);
    if (l > mid)
        return query3(n * 2 + 1, l, r);
    return query3(n * 2, l, mid) + query3(n * 2 + 1, mid + 1, r);
}

/* 4 a b: output the length of the longest continuous '1' string in [a, b]. */
int query4(int n, int l, int r)
{
    if (p[n].left == l && p[n].right == r)
        return p[n].m1;
    pushDown(n);
    const int mid = p[n].mid;
    if (r <= mid)
        return query4(n * 2, l, r);
    if (l > mid)
        return query4(n * 2 + 1, l, r);
    const int best = std::max(query4(n * 2, l, mid), query4(n * 2 + 1, mid + 1, r));
    // A run crossing the split: the left child's suffix of 1s clipped to [l, mid]
    // plus the right child's prefix of 1s clipped to [mid + 1, r].
    const int across = std::min(p[n * 2].r1, mid - l + 1) + std::min(p[n * 2 + 1].l1, r - mid);
    return std::max(best, across);
}
// Reads T test cases from stdin. Each case is "n m", then the n-element 0/1
// sequence, then m operations "x a b" on the inclusive, 0-based range [a, b]:
//   x = 0 / 1 assign that digit, x = 2 flip,
//   x = 3 print the number of 1s, x = 4 print the longest run of 1s.
// If the input ends early, processing stops instead of using uninitialized values.
// Assumes 1 <= n <= 100005 (size of the global array `a`) and 0 <= a <= b < n.
int main()
{
    int t;
    if (scanf("%d", &t) != 1)
        return 0;
    while (t--)
    {
        int n, m;
        if (scanf("%d%d", &n, &m) != 2)
            return 0;
        for (int i = 0; i < n; i++)
        {
            if (scanf("%d", &a[i]) != 1)
                return 0;
        }
        build(1, 0, n - 1);
        while (m--)
        {
            int x, l, r;
            if (scanf("%d%d%d", &x, &l, &r) != 3)
                return 0;
            if (x < 3)
                update(1, l, r, x);
            else if (x == 3)
                printf("%d\n", query3(1, l, r));
            else
                printf("%d\n", query4(1, l, r));
        }
    }
    return 0;
}