题意:求所有子区间的Mex的和。(Mex的含义有解释)
思路:
暴力对子区间的枚举肯定不行,只能从不同区间之间的Mex的关系去着手。
首先可以考虑固定子区间的起点,也就是L = 1,R = i,1<=i<=n,[L,R]的区间,由Mex的定义可以注意到随着R递增时,Mex一定是一个非递减序列。而且这一步可以很简单的直接递推初始化出所有L = 1的区间的Mex。
接下来考虑随着起点的变化,Mex会因为什么的影响而变化。对于当前考虑区间设起点L = i。
对于R = j,j>=i的子区间[L,R]。
1.若[L,R]存在k使得 num[k] == num[i-1],则很明显[L,R]的Mex和[L-1,R]的Mex是一样的,因为少了num[i-1]没有影响。
2.若上述不成立,则若mex[L-1,R] > num[i-1],则很明显mex[L,R] = num[i-1],因为mex总是取较小的不存在的。
3.若其他情况,则mex[L,R] = mex[L-1,R]
将上述一般情况综合起来考虑,其实就是每次枚举的起点在往后递推的时候,会“扔掉”num[i],然后在其后找到下一个和num[i]相同的值所在的位置j(若不存在,取为n+1),然后考虑区间[i+1,j-1]中,大于num[i]的mex值会发生改变,而mex的值一直是非递减的,所以只需要修改一段连续的区间,可以考虑用线段树来进行修改并区间求和。还需要一个最值来查询第一大于num[i]的mex的位置。
代码如下:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <cstdlib>
using namespace std;
#define lson root<<1,l,mid
#define rson root<<1|1,mid+1,r
const int maxn = 200008;
long long sum[maxn<<2];
int setv[maxn<<2],maxv[maxn<<2];
int mex[maxn],vis[maxn];
int num[maxn],next[maxn],pre[maxn];
void pushup(int root){
sum[root] = sum[root<<1] + sum[root<<1|1];
maxv[root] = max(maxv[root<<1] ,maxv[root<<1|1]);
}
void pushdown(int root,int l,int r){
if(setv[root] != -1){
setv[root<<1] = setv[root<<1|1] = setv[root];
int mid = (l+r)>>1;
sum[root<<1] = (long long)(mid-l+1)*setv[root];
sum[root<<1|1] = (long long)(r-mid)*setv[root];
maxv[root<<1] = maxv[root<<1|1] = setv[root];
setv[root] = -1;
}
}
void build(int root,int l,int r){
setv[root] = -1;
if(l==r){
maxv[root] = mex[l];
sum[root] = (long long)mex[l];
return ;
}
int mid = (l+r)>>1;
build(lson);
build(rson);
pushup(root);
}
void update(int p,int ll,int rr,int root,int l,int r){
if(ll>rr)
return ;
if(ll<=l&&r<=rr){
maxv[root] = setv[root] = p;
sum[root] = (long long)(r-l+1)*p;
return ;
}
pushdown(root,l,r);
int mid = (l+r)>>1;
if(ll<=mid)
update(p,ll,rr,lson);
if(rr>mid)
update(p,ll,rr,rson);
pushup(root);
}
int querypos(int val,int root,int l,int r){
if(maxv[root]<=val)
return r+1;
if(l==r)
return l;
pushdown(root,l,r);
int mid = (l+r)>>1;
if(maxv[root<<1] > val )
return querypos(val,lson);
else
return querypos(val,rson);
}
long long query(int ll,int rr,int root,int l,int r){
if(ll<=l&&r<=rr){
return sum[root];
}
pushdown(root,l,r);
int mid = (l+r)>>1;
long long ret = 0;
if(ll<=mid)
ret+=query(ll,rr,lson);
if(rr>mid)
ret+=query(ll,rr,rson);
return ret;
}
int main(){
int n;
while(scanf("%d",&n),n){
for(int i = 1;i<=n;i++){
scanf("%d",&num[i]);
}
int cnt = 0;
memset(mex,0,sizeof(mex));
memset(vis,0,sizeof(vis));
memset(pre,0,sizeof(pre));
for(int i = 1;i<=n;i++){
if(num[i]<maxn)
vis[num[i]] = 1;
//初始化mex函数值
for(int j = mex[i-1];;j++){
if(!vis[j]){
mex[i] = j;
break;
}
}
//标记每一个和num[i]相等的下一个num[j]的位置
if(num[i]<maxn){
next[pre[num[i]]] = i;
next[i] = n+1;
pre[num[i]] = i;
}
}
build(1,1,n);
long long ans = 0ll;
for(int l = 1;l<=n;l++){
ans+=query(l,n,1,1,n);
//选取更新区间
int pos = querypos(num[l],1,1,n);
if(num[l]<maxn){
update(num[l],pos,next[l]-1,1,1,n);
}
}
printf("%I64d\n",ans);
}
return 0;
}