Mex
Time Limit : 15000/5000ms (Java/Other) Memory Limit : 65535/65535K (Java/Other)
Total Submission(s) : 18 Accepted Submission(s) : 3
Problem Description
Mex is a function on sets of integers that is widely used in the theory of impartial games. For a set S of non-negative integers, mex(S) is defined as the least non-negative integer that does not appear in S. Our problem is about the mex function on a sequence.
Consider a sequence of non-negative integers {a_i}. We define mex(L,R) as the least non-negative integer that does not appear in the contiguous subsequence a_L, ..., a_R (inclusive). We want to calculate the sum of mex(L,R) over all 1 <= L <= R <= n.
Input
The input contains at most 20 test cases. For each test case, the first line contains one integer n, the length of the sequence. The next line contains n non-negative integers separated by spaces, denoting the sequence (1 <= n <= 200000, 0 <= a_i <= 10^9). The input ends with n = 0.
Output
For each test case, output one line containing an integer denoting the answer.
Sample Input
3
0 1 3
5
1 0 2 0 1
0
Sample Output
5
24
[hint] For the first test case: mex(1,1)=1, mex(1,2)=2, mex(1,3)=2, mex(2,2)=0, mex(2,3)=0, mex(3,3)=0, so the answer is 1 + 2 + 2 + 0 + 0 + 0 = 5. [/hint]
Source
2013 ACM/ICPC Asia Regional Hangzhou Online
代码:
#include <iostream>
#include <cstdio>
#include <cstring>
#include <set>
#include <map>
using namespace std;
typedef long long LL;
// Heap-style child indices and interval midpoint for the segment tree.
// Arguments are parenthesized so that expressions such as LL(a&b) or
// MID(x|y, z) expand with the intended precedence.
// Note: LL is also the 64-bit typedef above; the function-like macro only
// fires when LL is immediately followed by '('.
#define LL(x) ((x)<<1)
#define RR(x) ((x)<<1|1)
#define MID(a,b) ((a)+(((b)-(a))>>1))
const int N=200005;
// Scratch set of the values seen so far (used by main to compute suffix mex).
set<int> order;
// Maps each distinct value to a compact id (used by main to find previous occurrences).
map<int,int> H;
// a: input sequence; pre[i]: previous index holding a[i] (-1 if none);
// pos[id]: latest index seen for value id; valu[i]: initial leaf values for the tree.
int a[N],pre[N],pos[N],valu[N];
// Segment tree over positions 0..n-1 supporting range assignment, a global
// sum (sum[1]), point reads, and a "rightmost index in [st,ed] whose value is
// >= v" search. Per node `ind` covering [lft,rht]:
//   sum[ind]   - sum of the values in [lft,rht] (64-bit)
//   lKey[ind]  - value stored at position lft
//   rKey[ind]  - value stored at position rht (maintained but not read here)
//   delay[ind] - pending assignment for the whole node, or -1 when none
struct Segtree
{
    LL sum[N*4];
    int lKey[N*4],rKey[N*4],delay[N*4];

    // Recompute node `ind` from its children LL(ind) and RR(ind).
    void PushUp(int ind)
    {
        lKey[ind]=lKey[LL(ind)];
        rKey[ind]=rKey[RR(ind)];
        sum[ind]=sum[LL(ind)]+sum[RR(ind)];
    }

    // Assign `valu` to every position of node `ind` (covering [lft,rht]) and
    // record it in delay[] so PushDown can forward it to the children later.
    void fun(int valu,int lft,int rht,int ind)
    {
        lKey[ind]=rKey[ind]=delay[ind]=valu;
        // Widen before multiplying: valu*(length) can exceed INT_MAX (e.g. an
        // assigned value near 1e5 over a node ~1e5 wide when n is near 200000),
        // and the old int product overflowed before reaching the LL sum.
        sum[ind]=static_cast<LL>(valu)*(rht-lft+1);
    }

    // Forward a pending assignment of node `ind` ([lft,rht]) to both children.
    void PushDown(int ind,int lft,int rht)
    {
        if(delay[ind]!=-1)
        {
            int mid=MID(lft,rht);
            fun(delay[ind],lft,mid,LL(ind));
            fun(delay[ind],mid+1,rht,RR(ind));
            delay[ind]=-1;
        }
    }

    // Build node `ind` over [lft,rht], loading each leaf from the global valu[] array.
    void build(int lft,int rht,int ind)
    {
        sum[ind]=lKey[ind]=rKey[ind]=0; delay[ind]=-1;
        if(lft==rht) lKey[ind]=rKey[ind]=sum[ind]=valu[lft];
        else
        {
            int mid=MID(lft,rht);
            build(lft,mid,LL(ind));
            build(mid+1,rht,RR(ind));
            PushUp(ind);
        }
    }

    // Assign `valu` to every position in [st,ed]; node `ind` covers [lft,rht].
    void updata(int st,int ed,int valu,int lft,int rht,int ind)
    {
        if(st<=lft&&rht<=ed) fun(valu,lft,rht,ind);
        else
        {
            PushDown(ind,lft,rht);
            int mid=MID(lft,rht);
            if(st<=mid) updata(st,ed,valu,lft,mid,LL(ind));
            if(ed> mid) updata(st,ed,valu,mid+1,rht,RR(ind));
            PushUp(ind);
        }
    }

    // Return the rightmost index in [st,ed] whose value is >= valu, or -1 when
    // the descent ends on a leaf that fails that test. The branch choice only
    // inspects lKey of the right child, so the answer is correct only when the
    // stored values are non-increasing from left to right (main stores mex(L,i)
    // for a fixed right end i, which is non-increasing in L).
    int query(int st,int ed,int valu,int lft,int rht,int ind)
    {
        if(lft==rht)
        {
            if(lKey[ind]>=valu) return lft;
            return -1;
        }
        else
        {
            int mid=MID(lft,rht),pos=-1;
            PushDown(ind,lft,rht);
            if(ed<=mid) pos=query(st,ed,valu,lft,mid,LL(ind));
            else if(st>mid) pos=query(st,ed,valu,mid+1,rht,RR(ind));
            else
            {
                // Position mid+1 qualifies => the answer lies in the right half.
                if(lKey[RR(ind)]>=valu) pos=query(st,ed,valu,mid+1,rht,RR(ind));
                else pos=query(st,ed,valu,lft,mid,LL(ind));
            }
            PushUp(ind);
            return pos;
        }
    }

    // Return the value currently stored at position `pos`.
    int getValu(int pos,int lft,int rht,int ind)
    {
        if(lft==rht) return lKey[ind];
        else
        {
            int mid=MID(lft,rht),tmp=0;
            PushDown(ind,lft,rht);
            if(pos<=mid) tmp=getValu(pos,lft,mid,LL(ind));
            else tmp=getValu(pos,mid+1,rht,RR(ind));
            PushUp(ind);
            return tmp;
        }
    }
}seg;
int main()
{
int n;
while(scanf("%d",&n)!=EOF)
{
if(n==0) break;
int sc=0;
order.clear(); H.clear();
memset(pre,-1,sizeof(pre));
for(int i=0;i<n;i++)
{
scanf("%d",&a[i]);
if(H.find(a[i])==H.end()) { H.insert(make_pair(a[i],sc++)); }
else { pre[i]=pos[ H[a[i]] ]; }
pos[H[a[i]]]=i;
}
int tmp=0;
for(int i=n-1;i>=0;i--)
{
order.insert(a[i]);
while(order.find(tmp)!=order.end()) tmp++;
valu[i]=tmp;
}
seg.build(0,n-1,1);
LL ans=0;
for(int i=n-1;i>=0;i--)
{
ans+=seg.sum[1];
int st=pre[i]+1,ed=i-1,pos=-1;
if(st<=ed) pos=seg.query(st,ed,a[i],0,n-1,1);
if(pos!=-1) seg.updata(st,pos,a[i],0,n-1,1);
seg.updata(i,n-1,0,0,n-1,1);
}
printf("%lld\n",ans);
}
return 0;
}
错误代码 (wrong code):
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <set>
#include <map>
#include <string>
#include <math.h>
#include <stdlib.h>
#include <time.h>
// Child indices of heap-numbered segment-tree node rt. The argument is
// parenthesized so expressions such as L(a&b) expand with the intended precedence.
#define L(rt) ((rt)<<1)
#define R(rt) ((rt)<<1|1)
using namespace std;
typedef long long ll;
const int maxn=200010;
// mex[i] = mex of the prefix a[1..i] (filled in by main).
int mex[maxn];
// One segment-tree node. `lazy` marks that the whole interval was assigned the
// value now held in `maxx` and the children have not been updated yet.
struct Node
{
    int l, r;   // covered interval [l, r], inclusive
    ll sum;     // sum of the values in [l, r]
    int maxx;   // maximum value in [l, r]
    int lazy;   // nonzero => pending assignment of maxx to both children
};
Node tree[maxn*3];
void UpdateSame(int i,int v)
{
tree[i].sum =(ll) v * (tree[i].r - tree[i].l + 1);
tree[i].maxx = v;
tree[i].lazy = 1;
}
// Recompute the sum and maximum of internal node i from its children; leaves are left untouched.
void pushup(int i)
{
    if(tree[i].l != tree[i].r)
    {
        const int lc = L(i);
        const int rc = R(i);
        tree[i].sum = tree[lc].sum + tree[rc].sum;
        tree[i].maxx = (tree[lc].maxx > tree[rc].maxx) ? tree[lc].maxx : tree[rc].maxx;
    }
}
// If internal node i has a pending assignment, hand its value (kept in maxx)
// to both children and clear the flag. Leaves are ignored.
void pushdown(int i)
{
    if(tree[i].l == tree[i].r || !tree[i].lazy)
        return;
    const int v = tree[i].maxx;
    UpdateSame(L(i), v);
    UpdateSame(R(i), v);
    tree[i].lazy = 0;
}
// Build the subtree rooted at node i over positions [l, r]; each leaf starts
// from mex[] and no node starts with a pending assignment.
void Build(int i,int l,int r)
{
    tree[i].l = l;
    tree[i].r = r;
    tree[i].lazy = 0;
    if(l != r)
    {
        const int half = (l + r) >> 1;
        Build(L(i), l, half);
        Build(R(i), half + 1, r);
        pushup(i);
    }
    else
    {
        tree[i].sum = mex[l];
        tree[i].maxx = mex[l];
    }
}
// Assign v to every position in [l, r], which must lie inside node i's interval.
// The range is split at the node midpoint so each recursive call receives a
// range that exactly matches or lies inside its child.
void Update(int i,int l,int r,int v)
{
    if(tree[i].l != l || tree[i].r != r)
    {
        pushdown(i);
        const int half = (tree[i].l + tree[i].r) >> 1;
        if(r <= half)
            Update(L(i), l, r, v);
        else if(l > half)
            Update(R(i), l, r, v);
        else
        {
            Update(L(i), l, half, v);
            Update(R(i), half + 1, r, v);
        }
        pushup(i);
    }
    else
    {
        UpdateSame(i, v);
    }
}
// Walk down from node i toward the leftmost leaf whose value exceeds v,
// preferring the left child whenever its maximum exceeds v. Returns that leaf's
// position. If no value exceeds v, the walk still ends on some leaf, so callers
// check the root maximum first.
int query(int i,int v)
{
    int node = i;
    while(tree[node].l != tree[node].r)
    {
        pushdown(node);
        node = (tree[L(node)].maxx > v) ? L(node) : R(node);
    }
    return tree[node].l;
}
int a[maxn];
map<int,int>mp;
int next[maxn];
int main()
{
int n;
while(scanf("%d",&n) && n)
{
for(int i = 1;i <= n;i++)
scanf("%d",&a[i]);
mp.clear();
int tmp = 0;
for(int i = 1;i <= n;i++)
{
mp[a[i]] = 1;
while(mp.find(tmp)!=mp.end())
tmp++;
mex[i]=tmp;
}
mp.clear();
for(int i = n;i >= 1;i--)
{
if(mp.find(a[i])==mp.end())
next[i]=n+1;
else next[i]=mp[a[i]];
mp[a[i]]=i;
}
Build(1,1,n);
ll sum = 0;
for(int i = 1;i <= n;i++)
{
sum+=tree[1].sum;
if(tree[1].maxx>a[i])
{
int l=query(1,a[i]);
int r=next[i];
if(l<r)
Update(1,l,r-1,a[i]);
}
Update(1,i,i,0);
}
printf("%d\n",sum);
}
return 0;
}