题意:求所有子区间Mex的和。Mex是最小的,不存在集合里的非负整数。例如{1,2,4}的Mex等于0。
思路:
记区间[l,r]的Mex值为M(l,r)。
考虑n个数M(1,1),M(1,2),M(1,3)...M(1,n)组成的序列,必然是非递减的,记这样的序列为S(1)。
先求出S(1)(在线段树中插好这n个数),接着要求出M(2,2),M(2,3)...M(2,n),即S(2)。S(1)变成S(2),需要把第一个数去掉(记为为x[1])。现在考虑去掉x[1]后对S(1)的影响。
1.如果M(1,r)<x[1],那么去掉后M(1,r)变成M(2,r),且值不变。
2.如果M(1,r)>x[1],且区间[2,r]中不存在x[1],那么去掉后M(2,r)=x[1]。
于是我们在序列S(1)中找到恰好大于x[1]的位置,记为p。那么M(2,2),M(2,3)...M(2,p-1)的值不变,与M(1,2),M(1,3)..M(1,p-1)一样。然后我们再找到下一个值为x[1]的位置,记为q。那么M(2,p)...M(2,q-1)的值都是x[1](相当于要修改区间[p,q-1]的值了)。
从S(1)到S(n)就是全部的答案了。
用线段树来维护上述操作,就是区间更新,求和,以及查找p。对于q,只要维护一个next数组就好,具体看代码。
code:
#include <algorithm>
#include <iostream>
#include <string.h>
#include <stdlib.h>
#include <stdio.h>
#include <string>
#include <math.h>
#include <vector>
#include <queue>
#include <stack>
#include <cmath>
#include <list>
#include <set>
#include <map>
using namespace std;
/*-------------------------Template----*/
#define N 200020
#define E 100010
#define ll long long
#define CUBE(x) ((x)*(x)*(x))
#define SQ(x) ((x)*(x))
#define ALL(x) x.begin(),x.end()
#define CLR(x,a) memset(x,a,sizeof(x))
#define maxAry(a,n) max_element(a,a+(n))
#define minAry(a,n) min_element(a,a+(n))
typedef pair<int,int> PI;
const int INF=0x3fffffff;
const int PRIME =999983;
const int MOD =10007;
const int MULTI =1000000007;
const double EPS=1e-9;
const int dx[] = {0, 1, 0, -1};
const int dy[] = {1, 0, -1, 0};
inline bool isodd(int x){return x&1;}
/*----------------------end Template----*/
class segTree{
#define lson (rt<<1)
#define rson (rt<<1|1)
#define rtl seg[rt].l
#define rtr seg[rt].r
private:
struct segment{
int l,r,max,mark;
ll sum;
}seg[N<<2];
public:
void setVal(int val,int rt)
{
seg[rt].sum=1ll*(rtr-rtl+1)*val;
seg[rt].max=seg[rt].mark=val;
}
void pushup(int rt)
{
seg[rt].sum=seg[lson].sum+seg[rson].sum;
seg[rt].max=max(seg[lson].max,seg[rson].max);
}
void pushdown(int rt)
{
if(seg[rt].mark!=-1){
setVal(seg[rt].mark,lson);
setVal(seg[rt].mark,rson);
seg[rt].mark=-1;
}
}
void update(int val,int L,int R,int rt)
{
if(L<=rtl && rtr<=R){
setVal(val,rt);
return ;
}
int mid=(rtl+rtr)>>1;
pushdown(rt);
if(L<=mid) update(val,L,R,lson);
if(mid<R) update(val,L,R,rson);
pushup(rt);
}
ll query(int L,int R,int rt)
{
ll ans=0;
if(L<=rtl && rtr<=R)
return seg[rt].sum;
int mid=(rtl+rtr)>>1;
pushdown(rt);
if(L<=mid) ans+=query(L,R,lson);
if(mid<R) ans+=query(L,R,rson);
return ans;
}
int up_bound(int val,int rt)
{
if(rtl==rtr){
if(seg[rt].max<val) return -1;
return rtl;
}
pushdown(rt);
if(val<seg[lson].max)
return up_bound(val,lson);
else
return up_bound(val,rson);
}
void build(int l,int r,int rt)
{
rtl=l, rtr=r;
seg[rt].mark=-1;
if(l==r) return;
int mid=(rtl+rtr)>>1;
build(l,mid,lson);
build(mid+1,r,rson);
}
}T;
int n,a[N],d[N],Next[N];
bool vis[N];
int main()
{
int mex;
ll ans;
while(scanf("%d",&n),n){
CLR(vis,0);
mex=ans=0;
T.build(1,n,1);
for(int i=1;i<=n;i++){
scanf("%d",&a[i]);
a[i]=a[i]>n?n+1:a[i];
vis[a[i]]=true;
if(a[i]==mex)
for(int j=mex+1;j<=n;j++) if(!vis[j]){
mex=j; break;
}
T.update(mex,i,i,1);
ans+=mex;
}
CLR(d,-1);
for(int i=n;i>=1;i--){
if(d[a[i]]==-1) Next[i]=n+1;
else Next[i]=d[a[i]];
d[a[i]]=i;
}
for(int i=2;i<=n;i++){
int l=T.up_bound(a[i-1],1);
if(l!=-1) T.update(a[i-1],l,Next[i-1]-1,1);
ans+=T.query(i,n,1);
}
printf("%I64d\n", ans);
}
return 0;
}