题意:求所有子区间Mex的和。Mex是最小的,不存在集合里的非负整数
思路:
首先考虑从1开始的所以Mex值,必然是非递减的,并且可以O(n)求出。(具体看代码)
先考虑如何从1开始的Mex值求出从2开始的Mex值
1.如果Mex<a[1],那么去掉后Mex值不变
2.如果Mex>a[1],并且该区间内不包含另外一个a[x]==a[1],那么Mex=a[1]。
那么我们需要先预处理出next数组,保存下一个a[i]的位置,那么每次把这样一段区间内的大于a[i]的值置为a[i]即可,由于Mex单调,那么我们只要找到位置最靠前的并且大于a[i]的,然后把改位置以后,下一个a[i]的位置之前的值全部置为a[i]即可。
那么我们可以用线段树来维护修改,查询的操作。
#include <iostream>
#include <cstdio>
#include <cstring>
#define ls (t<<1)
#define rs (t<<1|1)
#define midt (tr[t].l+tr[t].r>>1)
using namespace std;
const int maxn=2e5+9;
int a[maxn],now[maxn],next[maxn],first[maxn];
bool flag[maxn];
int n;
struct
{
int l,r,max;
long long sum,lazy;
}tr[maxn<<2];
void pushup(int t)
{
tr[t].sum=tr[ls].sum+tr[rs].sum;
tr[t].max=max(tr[ls].max,tr[rs].max);
}
void maketree(int t,int l,int r)
{
tr[t].l=l;
tr[t].r=r;
tr[t].lazy=-1;
if(l==r)
{
tr[t].sum=now[l];
tr[t].max=now[l];
return ;
}
maketree(ls,l,midt);
maketree(rs,midt+1,r);
pushup(t);
}
void pushdown(int t)
{
if(tr[t].lazy==-1) return;
tr[ls].sum=(tr[ls].r-tr[ls].l+1)*tr[t].lazy;
tr[rs].sum=(tr[rs].r-tr[rs].l+1)*tr[t].lazy;
tr[ls].max=tr[rs].max=tr[t].lazy;
tr[ls].lazy=tr[rs].lazy=tr[t].lazy;
tr[t].lazy=-1;
}
int query(int t,int tmp)
{
if(tr[t].max<tmp) return n+1;
if(tr[t].l==tr[t].r) return tr[t].l;
pushdown(t);
if(tr[ls].max>=tmp) return query(ls,tmp);
return query(rs,tmp);
}
void modify(int t,int l,int r,long long tmp)
{
if(tr[t].l==l&&tr[t].r==r)
{
tr[t].sum=tmp*(r-l+1);
tr[t].lazy=tr[t].max=tmp;
return ;
}
pushdown(t);
if(r<=midt) modify(ls,l,r,tmp);
else if(midt+1<=l) modify(rs,l,r,tmp);
else
{
modify(ls,l,midt,tmp);
modify(rs,midt+1,r,tmp);
}
pushup(t);
}
int main()
{
// freopen("in.txt","r",stdin);
while(scanf("%d",&n),n)
{
for(int i=1;i<=n;i++)
scanf("%d",&a[i]);
memset(flag,0,sizeof(flag));
for(int i=1,top=0;i<=n;i++)
{
if(a[i]<=n)
flag[a[i]]=1;
while(flag[top]) top++;
now[i]=top;
}
for(int i=n;i>=0;i--) first[i]=n+1;
for(int i=n;i>=1;i--)
if(a[i]<=n)
{
next[i]=first[a[i]];
first[a[i]]=i;
}
else next[i]=n+1;
maketree(1,1,n);
long long ans=tr[1].sum;
for(int i=1;i<n;i++)
{
int ll=query(1,a[i]);
if(ll<=next[i]-1)
modify(1,ll,next[i]-1,a[i]);
ans+=tr[1].sum;
}
cout<<ans<<endl;
}
return 0;
}