题意:
给一个长度为
n
n
n的
01
01
01串,问你对于所有区间
[
l
,
r
]
[l,r]
[l,r],区间中最长连续
1
1
1的长度之和。
n
<
=
5
e
5
n<=5e5
n<=5e5
题解:
先来吐槽两句。这题当场和舍友一起打的,但是他们去写题了,我只是在边上看题口胡。结果这题我和舍友一起口胡了一个假做法之后就没管。结果寒假训练时阴差阳错又有人搬了这个题,结果一写发现之前做法假了,然后自己又连想好几个假做法。感觉这种可能可行的打开方式比较多的题还是经常找不准打开方式QAQ
回归正题,说做法。打开方式是从左向右扫一遍,每加入一个位置,就统计所有以当前位置为右端点的区间和。我们用线段树维护一个已经遍历过的位置的序列,序列上每个位置的权值表示当前右端点与序列上这个点为左端点构成的区间对答案的贡献,不难发现这个序列的数值是单调不升的,即越靠左的左端点与当前右端点组成的区间的最长连续 1 1 1会越长。我们考虑如果加入的位置数值是 1 1 1,那么会有一段区间的最长连续 1 1 1长度变长,而应该变长的这一段其实就是原来最长连续 1 1 1比当前位置为结尾的最长连续 1 1 1长度小的所有部分。而因为最长连续 1 1 1是连续变化的,所以每次最多让一段序列的贡献 + 1 +1 +1。又因为上面说的单调性,我们只需要记录序列的区间最大值,在线段树上二分即可找到这个边界。然后后面这一段的贡献应该要整体 + 1 +1 +1。然后不论当前位置是 0 0 0还是 1 1 1,每次都在修改后(如果有的话)累加上序列的权值和。这些东西都可以用线段树维护。
于是做完了,注意会爆int。
复杂度 O ( n l o g n ) O(nlogn) O(nlogn)。
代码:
#include <bits/stdc++.h>
using namespace std;
int n,a[500010],lian[500010];
long long ans;
char s[500010];
struct node
{
int l,r;
long long mx,s,tag;
}tr[2000010];
inline void build(int rt,int l,int r)
{
tr[rt].l=l;
tr[rt].r=r;
tr[rt].tag=0;
tr[rt].s=0;
tr[rt].mx=0;
if(l==r)
return;
int mid=(l+r)>>1;
build(rt<<1,l,mid);
build(rt<<1|1,mid+1,r);
}
inline void pushdown(int rt)
{
if(tr[rt].tag)
{
tr[rt<<1].tag+=tr[rt].tag;
tr[rt<<1|1].tag+=tr[rt].tag;
tr[rt<<1].mx+=tr[rt].tag;
tr[rt<<1|1].mx+=tr[rt].tag;
tr[rt<<1].s+=tr[rt].tag*(tr[rt<<1].r-tr[rt<<1].l+1);
tr[rt<<1|1].s+=tr[rt].tag*(tr[rt<<1|1].r-tr[rt<<1|1].l+1);
tr[rt].tag=0;
}
}
inline void update(int rt,int le,int ri)
{
int l=tr[rt].l,r=tr[rt].r;
if(le<=l&&r<=ri)
{
tr[rt].mx++;
tr[rt].tag++;
tr[rt].s+=r-l+1;
return;
}
pushdown(rt);
int mid=(l+r)>>1;
if(le<=mid)
update(rt<<1,le,ri);
if(mid+1<=ri)
update(rt<<1|1,le,ri);
tr[rt].mx=max(tr[rt<<1].mx,tr[rt<<1|1].mx);
tr[rt].s=tr[rt<<1].s+tr[rt<<1|1].s;
}
inline int query(int rt,long long x)
{
int l=tr[rt].l,r=tr[rt].r;
if(l==r)
{
if(tr[rt].mx>=x)
return l+1;
else
return l;
}
pushdown(rt);
int mid=(l+r)>>1;
if(tr[rt<<1|1].mx>=x)
return query(rt<<1|1,x);
else
return query(rt<<1,x);
}
int main()
{
scanf("%d",&n);
scanf("%s",s+1);
for(int i=1;i<=n;++i)
{
if(s[i]=='0')
{
a[i]=0;
lian[i]=0;
}
else
{
a[i]=1;
lian[i]=lian[i-1]+1;
}
}
build(1,1,n);
for(int i=1;i<=n;++i)
{
if(a[i])
{
int lst=query(1,lian[i]);
update(1,lst,i);
}
ans+=tr[1].s;
}
printf("%lld\n",ans);
return 0;
}