Description
Pupil 发现对于一个十进制数,无论怎么将其的数字重新排列,均不影响其是不是3 的倍数。他想研究对于二进制,是否也有类似的性质。于是他生成了一个长为n的二进制串,希望你对于这个二进制串的一个子区间,能求出其有多少位置不同的连续子串,满足在重新排列后(可包含前导 0)是一个3的倍数。两个位置不同的子区间指开始位置不同或结束位置不同。由于他想尝试尽量多的情况,他有时会修改串中的一个位置,并且会进行多次询问。
Input
从文件 binary.in 中读入数据。
输入第一行包含一个正整数 n,表示二进制数的长度。
之后一行 n 个空格隔开的整数,保证均是0或1,表示该二进制串。
之后一行一个整数 m,表示询问和修改的总次数。
之后m行每行为1 i,表示Pupil修改了串的第i个位置(0 变成 1 或 1 变成 0 ),或 2 l r,表示Pupil询问的子区间是[l,r]。
串的下标从 1 开始。
Output
输出到文件 binary.out 中。
对于每次询问,输出一行一个整数表示对应该询问的结果。
Sample Input
4
1 0 1 0
3
2 1 3
1 3
2 3 4
Sample Output
2
3
Data Constraint
对于 20% 的数据,1 ≤ n,m ≤ 100。
对于 50% 的数据,1 ≤ n,m ≤ 5000。
对于 100% 的数据,1 ≤ n,m ≤ 100000,l ≤ r。
Hint
对于第一个询问,区间 [2,2] 只有数字 0,是 3 的倍数,区间 [1,3] 可以重排成011(2) = 3(10) ,是3的倍数,其他区间均不能重排成 3 的倍数。对于第二个询问,全部三个区间均能重排成 3 的倍数(注意 00 也是合法的)。
Solution
我们考虑一个二进制数 a 如何变成十进制数:
就是等于
a1∗20+a2∗21+a3∗22+…+an∗2n−1 。我们知道一个十进制数如果各位加起来是3的倍数那么它就是3的倍数。
而像 1,2,4,8,16,32…… 分别距离3的倍数是相差 −1,1,−1,1,−1…… 。
于是我们得到这样的结论:先设区间长度为 L ,区间中1的个数为
S 。如果 S 是偶数就能两两配对(1配-1),一定能行。
而
S 是奇数的话, S=1 (只有一个1)肯定不行,而 S>1 可以这样:前面的1先两两配对,剩下三个1,用三个1或-1配对即可,这需要满足: L≥S+2
我们用线段树维护即可, O(N log N) 。
什么?这太难维护?确实,这太难维护了。
我们考虑正难则反,用线段树维护区间中不合法的子区间个数,用总区间个数减去就是答案了。
两个区间合并时计算跨两个区间的子区间个数,这需要维护许多信息。
重新看看我们到底要算什么:(设 C0,C1 分别表示区间中 0 和
1 的个数)- C1 是奇数,且 C0<2 的区间个数;
- C1=1 ,且 C0≥2 的区间个数( C0<2 的之前在第一种中已经算过了)。
于是我们维护几个信息(0同理,即反过来):
- 区间前后缀 1 的长度;①
- 区间前后缀一串1(可以没有),接个0(记着位置),再接一串1(也可以没有)的长度②。
什么 C0=0 就是 ① 和 ① 接起来,什么 C0=1 就是 ① 和 ② 接起来……
那么我们就可以直接分类讨论算了,细节和码量都很多。
Code
#include<cstdio>
#include<cctype>
using namespace std;
typedef long long LL;
const int N=1e5+5;
struct data
{
LL sum;
int l,r;
int l1,r1;//11111
int l2,r2,pl2,pr2;//11011
int l3,r3,pl3,pr3;//00100
int l4,r4;//00000
}f[N<<2];
int qx,qy;
int a[N];
inline int read()
{
int X=0,w=0; char ch=0;
while(!isdigit(ch)) w|=ch=='-',ch=getchar();
while(isdigit(ch)) X=(X<<1)+(X<<3)+(ch^48),ch=getchar();
return w?-X:X;
}
inline data merge(data f1,data f2)
{
data g;
g.l=f1.l,g.r=f2.r;
g.sum=f1.sum+f2.sum;
//11111
g.l1=f1.l1,g.r1=f2.r1;
if(f1.l1==f1.r-f1.l+1) g.l1+=f2.l1;
if(f2.r1==f2.r-f2.l+1) g.r1+=f1.r1;
g.sum+=(LL)(f1.r1+1)/2*(f2.l1/2)+(LL)(f1.r1/2)*((f2.l1+1)/2);
//11011
if(f1.l1==f1.r-f1.l+1)
{
g.l2=f1.l1+f2.l2;
g.pl2=f2.pl2?f1.l1+f2.pl2:0;
}else
{
g.l2=f1.l2;
g.pl2=f1.pl2;
if(f1.l2==f1.r-f1.l+1) g.l2+=f2.l1;
}
if(f2.r1==f2.r-f2.l+1)
{
g.r2=f2.r1+f1.r2;
g.pr2=f1.pr2?f2.r1+f1.pr2:0;
}else
{
g.r2=f2.r2;
g.pr2=f2.pr2;
if(f2.r2==f2.r-f2.l+1) g.r2+=f1.r1;
}
if(f2.pl2)
{
int ll=f1.r1,rr=f2.l2-f2.pl2+1,odd=(f2.pl2-1)&1;
g.sum+=(LL)(ll+(odd^1))/2;
if(rr>1) g.sum+=(LL)rr/2*((ll+odd)/2)+(LL)(rr-1)/2*((ll+(odd^1))/2);
}
if(f1.pr2)
{
int ll=f1.r2-f1.pr2+1,rr=f2.l1,odd=(f1.pr2-1)&1;
g.sum+=(LL)(rr+(odd^1))/2;
if(ll>1) g.sum+=(LL)ll/2*((rr+odd)/2)+(LL)(ll-1)/2*((rr+(odd^1))/2);
}
//00000
g.l4=f1.l4,g.r4=f2.r4;
if(f1.l4==f1.r-f1.l+1) g.l4+=f2.l4;
if(f2.r4==f2.r-f2.l+1) g.r4+=f1.r4;
//00100
if(f1.l4==f1.r-f1.l+1)
{
g.l3=f1.l4+f2.l3;
g.pl3=f2.pl3?f1.l4+f2.pl3:0;
}else
{
g.l3=f1.l3;
g.pl3=f1.pl3;
if(f1.l3==f1.r-f1.l+1) g.l3+=f2.l4;
}
if(f2.r4==f2.r-f2.l+1)
{
g.r3=f2.r4+f1.r3;
g.pr3=f1.pr3?f2.r4+f1.pr3:0;
}else
{
g.r3=f2.r3;
g.pr3=f2.pr3;
if(f2.r3==f2.r-f2.l+1) g.r3+=f1.r4;
}
if(f2.pl3 && f1.pr3!=1)
{
int ll=f1.r4,rr=f2.l3-f2.pl3+1;
g.sum+=(LL)ll*rr;
if(!f2.l4 && f1.r4) g.sum--;
}
if(f1.pr3 && f2.pl3!=1)
{
int ll=f1.r3-f1.pr3+1,rr=f2.l4;
g.sum+=(LL)ll*rr;
if(!f1.r4 && f2.l4) g.sum--;
}
return g;
}
void make(int v,int l,int r)
{
f[v].l=l,f[v].r=r;
if(l==r)
{
if(a[l])
{
f[v].l1=f[v].r1=f[v].sum=1;
f[v].pl3=f[v].pr3=1;
}else
{
f[v].l4=f[v].r4=1;
f[v].pl2=f[v].pr2=1;
}
f[v].l2=f[v].r2=1;
f[v].l3=f[v].r3=1;
return;
}
int mid=l+r>>1;
make(v<<1,l,mid);
make(v<<1|1,mid+1,r);
f[v]=merge(f[v<<1],f[v<<1|1]);
}
void change(int v,int l,int r)
{
if(l==r)
{
if(f[v].sum)
{
f[v].l1=f[v].r1=f[v].sum=0;
f[v].pl2=f[v].pr2=1;
f[v].l4=f[v].r4=1;
f[v].pl3=f[v].pr3=0;
}else
{
f[v].l1=f[v].r1=f[v].sum=1;
f[v].pl2=f[v].pr2=0;
f[v].l4=f[v].r4=0;
f[v].pl3=f[v].pr3=1;
}
return;
}
int mid=l+r>>1;
if(qx<=mid) change(v<<1,l,mid); else change(v<<1|1,mid+1,r);
f[v]=merge(f[v<<1],f[v<<1|1]);
}
data find(int v,int l,int r)
{
if(qx<=l && r<=qy) return f[v];
int mid=l+r>>1;
data ff;
bool pd=false;
if(qx<=mid) ff=find(v<<1,l,mid),pd=true;
if(qy>mid) ff=pd?merge(ff,find(v<<1|1,mid+1,r)):find(v<<1|1,mid+1,r);
return ff;
}
int main()
{
freopen("binary.in","r",stdin);
freopen("binary.out","w",stdout);
int n=read();
for(int i=1;i<=n;i++) a[i]=read();
make(1,1,n);
int m=read();
while(m--)
if(read()==1)
{
qx=read();
change(1,1,n);
}else
{
qx=read(),qy=read();
data ans=find(1,1,n);
LL all=((LL)qy-qx+1)*(qy-qx+2)>>1;
printf("%lld\n",all-ans.sum);
}
return 0;
}