题意:
给定长度为n的序列a,
求所有子区间的mex的和.
数据范围:n<=2e5,0<=a(i)<=1e9
解法:
考虑枚举左端点,计算所有右端点的mex,
那么对于每个左端点的所有右端点的mex,全部加起来就是答案.
考虑递推,令mex[i]表示区间[1,i]的mex,
左端点为1的mex[]可以O(1)计算出来.
接下来考虑如何用[1,i]的mex[]计算出[2,i]的mex[].
[1,i]->[2,i]时,要删掉a[1],
考虑删掉a[1]会对哪些mex[i]造成影响,
设val=a[1],下一个val出现的位置是pos,即a[pos]=val,
因为pos以及pos后面的所有位置,都有val了,删掉无影响,
那么[1,pos-1]的mex[]中,大于val的都要变成val,
由于mex[]一定是非递减的,那么找到[1,pos-1]中第一个大于val的位置t,
那么操作变为[t,pos-1]的区间覆盖,覆盖为val.
用线段树做这道题,维护区间max和区间sum,
其中区间max用于树上二分找第一个大于val的位置t.
因为还需要知道下一个val出现的位置,所以还需要预处理一下nt[].
code:
#include <bits/stdc++.h>
using namespace std;
#define int long long
const int maxm=2e5+5;
int nt[maxm];
int a[maxm];
int b[maxm];
int n;
struct Tree{
int ma[maxm<<2];
int sum[maxm<<2];
int laz[maxm<<2];
inline void pp(int node){
sum[node]=sum[node*2]+sum[node*2+1];
ma[node]=max(ma[node*2],ma[node*2+1]);
}
inline void pd(int node,int l,int r){
if(laz[node]!=-1){
int mid=(l+r)/2;
sum[node*2]=(mid-l+1)*laz[node];
sum[node*2+1]=(r-mid)*laz[node];
ma[node*2]=laz[node];
ma[node*2+1]=laz[node];
laz[node*2]=laz[node];
laz[node*2+1]=laz[node];
laz[node]=-1;
}
}
void update(int st,int ed,int val,int l,int r,int node){//区间覆盖
if(st<=l&&ed>=r){
sum[node]=(r-l+1)*val;
ma[node]=val;
laz[node]=val;
return ;
}
pd(node,l,r);
int mid=(l+r)/2;
if(st<=mid)update(st,ed,val,l,mid,node*2);
if(ed>mid)update(st,ed,val,mid+1,r,node*2+1);
pp(node);
}
int ask(int st,int ed,int l,int r,int node){//区间求和
if(st<=l&&ed>=r)return sum[node];
pd(node,l,r);
int mid=(l+r)/2;
int ans=0;
if(st<=mid)ans+=ask(st,ed,l,mid,node*2);
if(ed>mid)ans+=ask(st,ed,mid+1,r,node*2+1);
return ans;
}
int ffind(int val,int l,int r,int node){//树上二分找大于val的第一个位置
if(ma[node]<=val)return n+1;
if(l==r)return l;
pd(node,l,r);
int mid=(l+r)/2;
int ans=n+1;
ans=ffind(val,l,mid,node*2);
if(ans==n+1)ans=ffind(val,mid+1,r,node*2+1);
return ans;
}
void build(int l,int r,int node){//建树
laz[node]=-1;
if(l==r){
sum[node]=b[l];
ma[node]=b[l];
return ;
}
int mid=(l+r)/2;
build(l,mid,node*2);
build(mid+1,r,node*2+1);
pp(node);
}
}T;
signed main(){
ios::sync_with_stdio(0);cin.tie(0);
while(cin>>n&&n){
for(int i=1;i<=n;i++){
cin>>a[i];
}
int now=0;
map<int,int>mark;
for(int i=1;i<=n;i++){
mark[a[i]]++;
while(mark[now])now++;
b[i]=now;
}
mark.clear();
for(int i=n;i>=1;i--){
if(!mark[a[i]]){
nt[i]=n+1;
}else{
nt[i]=mark[a[i]];
}
mark[a[i]]=i;
}
T.build(1,n,1);
int ans=0;
for(int i=1;i<=n;i++){
ans+=T.ask(i,n,1,n,1);
if(i!=n){
int t=T.ffind(a[i],1,n,1);
int pos=nt[i];
if(t<=pos-1){
T.update(t,pos-1,a[i],1,n,1);
}
}
}
cout<<ans<<endl;
}
return 0;
}