杭州网赛的题目,看数据范围就知道是线段树。
题目让求所有mex的和,首先可以用o(n)的算法(其实是不超过2*n)预处理出从1到i(i从1到n)的串的mex,然后,依次求2到i(i从2到n)。。。。的mex并求和。
mex有一个性质,即mex【i,j】<=mex【i,j+1】,这个性质可以很容易推得,此题便是利用了这个性质来求解。
当求出下标1开头的所有mex后,求下标2开头的序列时,相当于把原来存在于序列中的第一个数字去掉了,此时,如果这个数字在2到n中没有出现过,那么2到n中所有大于这个数字的mex都可以变成这个数字(因为这个数字在之后没有出现过),如果出现过,那么设这个数字在后面第一次出现的位置为j,那么2到j-1部分内没有出现过这个数字,这个范围内所有大于这个数的mex统统可以变成这个数。依次这样求3,4,。。。。。求和便可以。
此外,由于mex最大为n+1,故对序列中所有大于n的数字都可以不求其下一个出现的位置,因为这个数字不会导致更新。
代码:
#include <iostream>
#include <cstring>
#include <cstdio>
using namespace std;
#define lson l,m,rt<<1
#define rson m+1,r,rt<<1|1
#define delf int m=(l+r)>>1
int next[200020];
int v[200020];
long long int sum[800020]; //求和,注意爆int
int me[200020];
int rank[800080]; //记录是否有更新
int mex[800080]; //记录从当前求的起始位置到i的序列的mex。
int p[200020];
int pre[200020]; //记录该数字之前是否出现过
int n;
void pushup(int rt)
{
sum[rt]=sum[rt<<1]+sum[rt<<1|1];
mex[rt]=mex[rt<<1|1]; //右端点决定序列的mex,右边mex(i~r)大于左边(i~l),
}
void pushdown(int l,int r,int rt)
{
if (rank[rt]==-1)
return ;
delf;
sum[rt<<1]=(long long)(m-l+1)*rank[rt];
sum[rt<<1|1]=(long long)(r-m)*rank[rt];
rank[rt<<1]=rank[rt<<1|1]=rank[rt];
mex[rt<<1]=mex[rt<<1|1]=rank[rt];
rank[rt]=-1;
return ;
}
void build(int l,int r,int rt)
{
rank[rt]=-1;
if (l==r)
{
sum[rt]=me[l];
mex[rt]=me[l];
return ;
}
delf;
build(lson);
build(rson);
pushup(rt);
//cout<<"AAA "<<l<<" "<<r<<" "<<sum[rt]<<" "<<sum[rt<<1]<<" "<<sum[rt<<1|1]<<endl;
}
void update(int L,int R,int l,int r,int rt,int v)
{
if (L<=l&&r<=R)
{
//cout<<l<<" "<<r<<" "<<v<<endl;
sum[rt]=(long long)v*(r-l+1);
rank[rt]=mex[rt]=v;
return ;
}
pushdown(l,r,rt);
delf;
if (L<=m)
update(L,R,lson,v);
if (R>m)
update(L,R,rson,v);
pushup(rt);
//cout<<"AAA "<<l<<" "<<r<<" "<<sum[rt]<<" "<<sum[rt<<1]<<" "<<sum[rt<<1|1]<<endl;
return ;
}
long long int query(int L,int R,int l,int r,int rt)
{
if (L<=l&&r<=R)
return sum[rt];
pushdown(l,r,rt);
delf;
long long int s=0;
if (L<=m)
s=s+query(L,R,lson);
if (R>m)
s=s+query(L,R,rson);
pushup(rt);
return s;
}
int find(int l,int r,int rt,long long int v) //寻找mex大于v[i]的第一次出现的位置,如果没有返回n+1
{
if (mex[rt]<=v)
return n+1;
if (l==r)
return l;
pushdown(l,r,rt);
delf;
if (mex[rt<<1]>v)
return find(lson,v);
else
return find(rson,v);
}
int main()
{
while (scanf("%d",&n)&&n)
{
int mm=0;
memset(sum,0,sizeof(sum));
memset(p,0,sizeof(p));
memset(pre,0,sizeof(pre));
for (int i=1;i<=n;i++)
{
scanf("%d",&v[i]);
if (v[i]<=n)
p[v[i]]=1;
while (p[mm]) //如果这个数字出现过,mm++
mm++;
me[i]=mm;
next[i]=n+1;
}
build(1,n,1);
for (int i=1;i<=n;i++) //求next数组
{
if (v[i]<=n)
{
if (pre[v[i]]==0)
pre[v[i]]=i;
else
{
int c=pre[v[i]];
next[c]=i;
pre[v[i]]=i;
}
}
}
long long int ans=0;
//cout<<sum[1]<<endl;
for (int i=1;i<=n;i++)
{
ans=ans+query(i,n,1,n,1);
int ll=find(1,n,1,v[i]); //寻找第一次出现的位置
//cout<<ll<<" "<<next[i]-1<<endl;
if (ll<next[i]) //如果大于next[i],就无需更新,因为即使当前的数字去掉,后面可能更新仍有一个相同的数字,无法更新
update(ll,next[i]-1,1,n,1,v[i]);
//cout<<ans<<" "<<query(i+1,n,1,n,1)<<endl<<endl;
}
printf("%I64d\n",ans);
}
}