题意:
lxhgww最近收到了一个01序列,序列里面包含了n个数,这些数要么是0,要么是1,现在对于这个序列有五种变换操作和询问操作:
0 a b 把[a, b]区间内的所有数全变成0
1 a b 把[a, b]区间内的所有数全变成1
2 a b 把[a,b]区间内的所有数全部取反,也就是说把所有的0变成1,把所有的1变成0
3 a b 询问[a, b]区间内总共有多少个1
4 a b 询问[a, b]区间内最多有多少个连续的1
对于每一种询问操作,lxhgww都需要给出回答,聪明的程序员们,你们能帮助他吗?
题解:
线段树版本:
每个结点要维护:
- lazy标记该区间是否整个被赋值
- con标记该区间是否被反转
- sum标记该区间1的个数
- mx0/1标记该区间最多的连续0/1个数
- lmx0/1标记该区间从左端点开始最多的连续的0/1个数
- rmx0/1标记该区间从右端点开始最多的连续的0/1个数
然后维护起来细节比较多,调了一下午才A。
珂朵莉树版本:
每个询问暴力查询,赋值的时候合并区间。思维量少代码短,柯学万岁(≧▽≦)/
ac代码(线段树):
#include<bits/stdc++.h>
#define mid ((l+r)>>1)
#define lson rt<<1, l, mid
#define rson rt<<1|1, mid + 1, r
using namespace std;
const int maxn = 1e5 + 50;
int lz[maxn<<2], con[maxn<<2], sum[maxn<<2], mx[maxn<<2][2], lmx[maxn<<2][2], rmx[maxn<<2][2];
int a[maxn];
int n, m;
void up(int rt, int l, int r)
{
sum[rt] = sum[rt<<1] + sum[rt<<1|1];
for(int i = 0; i < 2; ++i){
lmx[rt][i] = lmx[rt<<1][i];
if(sum[rt<<1] == i * (mid - l + 1)){
lmx[rt][i] += lmx[rt<<1|1][i];
}
rmx[rt][i] = rmx[rt<<1|1][i];
if(sum[rt<<1|1] == i * (r - mid)){
rmx[rt][i] += rmx[rt<<1][i];
}
mx[rt][i] = rmx[rt<<1][i] + lmx[rt<<1|1][i];
mx[rt][i] = max(mx[rt][i], mx[rt<<1][i]);
mx[rt][i] = max(mx[rt][i], mx[rt<<1|1][i]);
}
return;
}
void build(int rt, int l, int r)
{
lz[rt] = -1;
con[rt] = 0;
if(l == r){
int v;
scanf("%d", &v);
sum[rt] = v;
mx[rt][v] = 1;
lmx[rt][v] = 1;
rmx[rt][v] = 1;
return;
}
build(lson);
build(rson);
up(rt, l, r);
return;
}
void down(int rt, int l, int r)
{
if(lz[rt] != -1)
{
lz[rt<<1] = lz[rt<<1|1] = lz[rt];
int v = lz[rt];
lz[rt] = -1;
con[rt] = 0;
con[rt<<1] = con[rt<<1|1] = 0;
sum[rt<<1] = v * (mid - l + 1);
sum[rt<<1|1] = v * (r - mid);
mx[rt<<1][v] = lmx[rt<<1][v] = rmx[rt<<1][v] = mid - l + 1;
mx[rt<<1][v^1] = lmx[rt<<1][v^1] = rmx[rt<<1][v^1] = 0;
mx[rt<<1|1][v] = lmx[rt<<1|1][v] = rmx[rt<<1|1][v] = r - mid;
mx[rt<<1|1][v^1] = lmx[rt<<1|1][v^1] = rmx[rt<<1|1][v^1] = 0;
return;
}
if(con[rt]){
con[rt<<1] ^= 1;
con[rt<<1|1] ^= 1;
if(lz[rt<<1] != -1){
lz[rt<<1] ^= 1;
}
if(lz[rt<<1|1] != -1){
lz[rt<<1|1] ^= 1;
}
sum[rt<<1] = mid - l + 1 - sum[rt<<1];
sum[rt<<1|1] = r - mid - sum[rt<<1|1];
for(int i =(rt<<1); i <= (rt<<1|1); ++i){
swap(mx[i][0], mx[i][1]);
swap(lmx[i][0], lmx[i][1]);
swap(rmx[i][0], rmx[i][1]);
}
con[rt] = 0;
}
return;
}
void update(int rt, int l, int r, int L, int R, int op)
{
//cout<<"l:"<<l<<" r:"<<r<<" sum:"<<sum[rt]<<endl;
if(L <= l && r <= R)
{
if(op <= 1){
lz[rt] = op;
con[rt] = 0;
sum[rt] = op * (r - l + 1);
mx[rt][op] = lmx[rt][op] = rmx[rt][op] = r - l + 1;
mx[rt][op^1] = lmx[rt][op^1] = rmx[rt][op^1] = 0;
return;
}
else{
con[rt] ^= 1;
if(lz[rt] != -1) {
lz[rt] ^= 1;
}
sum[rt] = r - l + 1 - sum[rt];
swap(mx[rt][0], mx[rt][1]);
swap(lmx[rt][0], lmx[rt][1]);
swap(rmx[rt][0], rmx[rt][1]);
return;
}
}
down(rt, l, r);
if(L <= mid) update(lson ,L, R, op);
if(R > mid) update(rson, L, R, op);
up(rt, l, r);
return;
}
int query(int rt, int l, int r, int L, int R, int op)
{
if(L <= l && r <= R){
if(op == 3) return sum[rt];
else return mx[rt][1];
}
down(rt, l, r);
int ans = 0;
if(L <= mid){
if(op == 3) ans += query(lson, L, R, op);
else ans = max(ans, query(lson, L, R, op));
}
if(R > mid){
if(op == 3) ans += query(rson, L, R, op);
else ans = max(ans, query(rson, L, R, op));
}
if(op == 4 && L <= mid && R > mid){
int tl = min(mid - L + 1, rmx[rt<<1][1]);
int tr = min(R - mid, lmx[rt<<1|1][1]);
ans = max(ans, tl + tr);
}
up(rt, l, r);
return ans;
}
int main()
{
scanf("%d%d",&n, &m);
build(1, 0, n-1);
while(m--)
{
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op < 3) update(1, 0, n-1, l, r, op);
else printf("%d\n", query(1, 0, n-1, l, r, op));
}
}
/*
10 10
1 0 1 1 1 1 0 1 1 1
4 0 9
*/
ac代码(珂朵莉树):
#include<bits/stdc++.h>
using namespace std;
const int maxn = 1e5 + 50;
struct node{
int l, r;
mutable int val;
node(int a = 0, int b = 0, int c = 0):l(a), r(b), val(c){}
bool operator < (const node& x)const{return l < x.l;}
};
set<node> s;
set<node>::iterator split(int x){
set<node>::iterator it = s.lower_bound(node(x));
if(it!=s.end() && it->l == x) return it;
it--;
int l = it->l, r = it->r, val = it->val;
s.erase(it);
s.insert(node(l, x-1, val));
return s.insert(node(x, r, val)).first;
}
void update(int l, int r, int val){
set<node>::iterator rit = split(r+1), lit = split(l);
s.erase(lit, rit);
s.insert(node(l, r, val));
}
void rev(int l, int r){
set<node>::iterator rit = split(r+1), lit = split(l);
while(lit!=rit) lit->val ^= 1, lit++;
}
int sum(int l, int r){
set<node>::iterator rit = split(r+1), lit = split(l);
int ans = 0;
while(lit!=rit) {
ans += lit->val*(lit->r - lit->l + 1), lit++;
}
return ans;
}
int query(int l, int r){
int mx = 0;
int res = 0;
set<node>::iterator rit = split(r+1), lit = split(l);
while(lit!=rit){
if(lit->val == 0) res = 0;
else{
res += lit->val*(lit->r - lit->l + 1);
mx = max(mx, res);
}
lit++;
}
return mx;
}
int n, m;
int a[maxn];
int main()
{
scanf("%d%d", &n, &m);
int pre = 0;
for(int i = 0; i < n; ++i) {
scanf("%d", &a[i]);
if(a[i] == a[pre]) continue;
else {
s.insert(node(pre, i-1, a[pre]));
pre = i;
}
}
s.insert(node(pre, n-1, a[n-1]));
while(m--){
int op, l, r;
scanf("%d%d%d", &op, &l, &r);
if(op < 2) update(l, r, op);
else if(op == 2) rev(l ,r);
else if(op == 3){
printf("%d\n", sum(l, r));
}
else printf("%d\n", query(l ,r));
}
}