将所有的数换种写法
mex = mex(1,1)+mex(1,2)+...+mex(1,n)+mex(2,2)+...+mex(2,n)+...mex(n,n).
应该知道如何算出 左端点不动的mex 也就是 mex(i,i) +....+mex(i,n)...
那么如何进行下一次 就要删除左端点。更新的时候也好想到。
但是要知道 mex(i,i) ...mex(i,n) 是递增的。 所以当mex(i,n) > a[i] 的时候 就要把 a[i] 到下一个 a[i] 出现的之间全部赋值为a[i]
比如
4
0 1 3 0
mex=2 但是 mex(1,1)=1, mex(1,2)=2, mex(1,3)=2;mex(1,4)=2;
如果此时去掉了0
那么 mex(2,2)=mex(2,3)就变成了0 因为缺掉了一个较小的数,那么当然就是这个小的数成为了mex
#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#define maxn 200005
#define lson num<<1,s,mid
#define rson num<<1|1,mid+1,e
#define inf 0x3f3f3f3f
using namespace std;
typedef long long LL;
LL sum[maxn<<2];
int mex[maxn<<2];
int cov[maxn<<2];
int n;
int a[maxn],next[maxn],mx[maxn],cnt;
map<int,int>mp;
void pushup(int num)
{
sum[num]=sum[num<<1]+sum[num<<1|1];
mex[num]=max(mex[num<<1],mex[num<<1|1]);
}
void pushdown(int num,int s,int e)
{
if(cov[num])
{
int mid=(s+e)>>1;
sum[num<<1]=(LL)mex[num]*(mid-s+1);
sum[num<<1|1]=(LL)mex[num]*(e-mid);
mex[num<<1]=mex[num<<1|1]=mex[num];
cov[num<<1]=cov[num<<1|1]=cov[num];
cov[num]=0;
}
}
void build(int num,int s,int e)
{
cov[num]=0;
if(s==e)
{
sum[num]=mex[num]=mx[s];
return ;
}
int mid=(s+e)>>1;
build(lson);
build(rson);
pushup(num);
}
void update(int num,int s,int e,int l,int r,int v)
{
if(l<=s && r>=e)
{
cov[num]=1;
sum[num]=(LL)v*(e-s+1);
mex[num]=v;
return;
}
pushdown(num,s,e);
int mid=(s+e)>>1;
if(l<=mid)update(lson,l,r,v);
if(r>mid)update(rson,l,r,v);
pushup(num);
}
int query(int num,int s,int e,int v)
{
if(s==e)return s;
int mid=(s+e)>>1;
pushdown(num,s,e);
if(mex[num<<1]>v)return query(lson,v);
else return query(rson,v);
}
int main()
{
while(scanf("%d",&n)!=EOF && n)
{
for(int i=1;i<=n;i++)scanf("%d",&a[i]);
mp.clear();
int tmp=0;
for(int i=1;i<=n;i++)
{
mp[a[i]]=1;
while(mp.find(tmp)!=mp.end())tmp++;
mx[i]=tmp;
}
mp.clear();
for(int i=n;i>=1;i--)
{
if(mp.find(a[i])==mp.end())next[i]=n+1;
else next[i]=mp[a[i]];
mp[a[i]]=i;
}
cnt=1;
build(1,1,n);
LL ans=0;
for(int i=1;i<=n;i++)
{
ans+=sum[1];
// printf("mex = %d a[i] = %d sum = %I64d\n",mex[1],a[i],sum[1]);
if(mex[1]>a[i])
{
int l=query(1,1,n,a[i]);
int r=next[i];
if(l<r)update(1,1,n,1,r-1,a[i]);
}
update(1,1,n,i,i,0);
}
printf("%I64d\n",ans);
}
return 0;
}