Description
n≤200000,ai≤109
Analysis
首先可以o(n)求出mex(1,i)(1<=i<=n)
然后,考虑通过mex(1,i)求mex(2,i)
然后,问题变成若删除a[1],对后面造成什么影响。
首先如果后面有一个a[y]=a[1],显然a[y+1]~a[n]是不受影响的。
所以再1~y中,找到mex(1,i)>a[1]的那些,由于mex(1,i)向右单增,所以可以直接用线段树上二分找出位置。
对于那些位置,区间修改将其变为a[1]。
最后统计答案,往后继续删除。
Code
#include<cstdio>
#include<cstring>
#include<algorithm>
#define fo(i,a,b) for(int i=a;i<=b;i++)
using namespace std;
typedef long long ll;
const int N=200010,INF=1e9;
int n,pos,a[N],last[N],left[N],right[N];
ll ans;
bool exist[N];
struct segment
{
int mx,sm,lz;
}tr[N*4];
void down(int v,int l,int r)
{
if(!tr[v].lz) return;
int mid=(l+r)>>1,lz=tr[v].lz-1;
tr[v+v].mx=tr[v+v+1].mx=lz;
tr[v+v].sm=(mid-l+1)*lz,tr[v+v+1].sm=(r-mid)*lz;
tr[v+v].lz=tr[v+v+1].lz=lz+1;
tr[v].lz=0;
}
void change(int v,int l,int r,int x,int y,int z)
{
if(l==x && r==y)
{
tr[v].mx=z,tr[v].sm=(r-l+1)*z,tr[v].lz=z+1;
return;
}
down(v,l,r);
int mid=(l+r)>>1;
if(y<=mid) change(v+v,l,mid,x,y,z);
else
if(x>mid) change(v+v+1,mid+1,r,x,y,z);
else
change(v+v,l,mid,x,mid,z),change(v+v+1,mid+1,r,mid+1,y,z);
tr[v].mx=max(tr[v+v].mx,tr[v+v+1].mx);
tr[v].sm=tr[v+v].sm+tr[v+v+1].sm;
}
int query(int v,int l,int r,int x,int y)
{
if(l==x && r==y) return tr[v].sm;
down(v,l,r);
int mid=(l+r)>>1;
if(y<=mid) return query(v+v,l,mid,x,y);
else
if(x>mid) return query(v+v+1,mid+1,r,x,y);
else
return query(v+v,l,mid,x,mid)+query(v+v+1,mid+1,r,mid+1,y);
}
void find(int v,int l,int r,int z)
{
if(tr[v].mx<=z) return;
if(l==r)
{
pos=min(pos,l);
return;
}
down(v,l,r);
int mid=(l+r)>>1;
if(tr[v+v].mx>z) find(v+v,l,mid,z);
else
if(tr[v+v+1].mx>z) find(v+v+1,mid+1,r,z);
}
void serch(int v,int l,int r,int x,int y,int z)
{
if(tr[v].mx<=z) return;
if(l==x && r==y)
{
find(v,l,r,z);
return;
}
down(v,l,r);
int mid=(l+r)>>1;
if(y<=mid) serch(v+v,l,mid,x,y,z);
else
if(x>mid) serch(v+v+1,mid+1,r,x,y,z);
else
{
serch(v+v,l,mid,x,mid,z);
serch(v+v+1,mid+1,r,mid+1,y,z);
}
}
int main()
{
scanf("%d",&n);
fo(i,1,n)
{
scanf("%d",&a[i]);
if(a[i]>n) a[i]=n+1;
left[i]=last[a[i]],right[last[a[i]]]=i;
last[a[i]]=i;
}
int k=0;
fo(i,1,n)
{
exist[a[i]]=1;
while(exist[k]) k++;
ans+=k;
change(1,1,n,i,i,k);
}
fo(x,1,n-1)
{
int y=right[x]-1,val=a[x];
if(y==-1) y=n;
pos=n+1;
if(x<y) serch(1,1,n,x+1,y,val);
if(pos<=y) change(1,1,n,pos,y,val);
ans+=query(1,1,n,x+1,n);
}
printf("%lld",ans);
return 0;
}