题目定义了mex(i,j)表示,没有在i到j之间出现的最小的非负整数。
求所有组合的i,j(i<=j)的和
就是求mex(1,1) + mex(1,2)+....+mex(1,n)
+mex(2,2) + mex(2,3) + ...mex(2,n)
+mex(3,3) + mex(3,4)+...+mex(3,n)
+ mex(n,n)
可以知道mex(i,i),mex(i,i+1)到mex(i,n)是递增的。
首先很容易求得mex(1,1),mex(1,2)......mex(1,n)
因为上述n个数是递增的。
然后使用线段树维护,需要不断删除前面的数。
比如删掉第一个数a[1]. 那么在下一个a[1]出现前的 大于a[1]的mex值都要变成a[1]
因为是单调递增的,所以找到第一个 mex > a[1]的位置,到下一个a[1]出现位置,这个区间的值变成a[1].
然后需要线段树实现区间修改和区间求和。
#include <set>
#include <map>
#include <stack>
#include <queue>
#include <deque>
#include <cmath>
#include <vector>
#include <string>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
#define L(i) i<<1
#define R(i) i<<1|1
#define INF 0x3f3f3f3f
#define pi acos(-1.0)
#define eps 1e-9
#define maxn 100100
#define MOD 1000000007
struct node
{
int l,r;
long long sum;
int mx,lazy;
} tree[maxn*6];
int a[maxn<<1],nxt[maxn<<1];
map<int,int> mp;
int n,mex[maxn<<1];
void pushup(int pos)
{
if(tree[pos].l == tree[pos].r)
return;
tree[pos].sum = tree[pos<<1].sum + tree[pos<<1|1].sum;
tree[pos].mx = max(tree[pos<<1].mx,tree[pos<<1|1].mx);
}
void pushdown(int pos)
{
if(tree[pos].l == tree[pos].r)
return;
if(!tree[pos].lazy)
return;
long long k = (long long)tree[pos].mx;
tree[pos<<1].sum = k*(tree[pos<<1].r - tree[pos<<1].l + 1);
tree[pos<<1].mx = k;
tree[pos<<1].lazy = 1;
tree[pos<<1|1].sum = k*(tree[pos<<1|1].r - tree[pos<<1|1].l + 1);
tree[pos<<1|1].mx = k;
tree[pos<<1|1].lazy = 1;
tree[pos].lazy = 0;
}
void build(int pos,int l,int r)
{
tree[pos].l = l;
tree[pos].r = r;
tree[pos].lazy = 0;
if(l == r)
{
tree[pos].mx = mex[l];
tree[pos].sum = mex[l];
return;
}
int mid = (l + r) >> 1;
build(pos<<1,l,mid);
build(pos<<1|1,mid+1,r);
pushup(pos);
}
void update(int pos,int l,int r,int v)
{
if(tree[pos].l == l && tree[pos].r == r)
{
tree[pos].sum = (long long)v*(tree[pos].r - tree[pos].l + 1);
tree[pos].mx = (long long)v;
tree[pos].lazy = 1;
return ;
}
pushdown(pos);
int mid = (tree[pos].l + tree[pos].r) >> 1;
if(r <= mid)
update(pos<<1,l,r,v);
else if(l > mid)
update(pos<<1|1,l,r,v);
else
{
update(pos<<1,l,mid,v);
update(pos<<1|1,mid+1,r,v);
}
pushup(pos);
}
int query(int pos,int v)
{
if(tree[pos].l == tree[pos].r)
return tree[pos].l;
pushdown(pos);
if(tree[pos<<1].mx > v)
return query(pos<<1,v);
else
return query(pos<<1|1,v);
}
int main()
{
//freopen("in.txt","r",stdin);
//freopen("out.txt","w",stdout);
int t,C = 1;
//scanf("%d",&t);
while(scanf("%d",&n) && n)
{
for(int i = 1; i <= n; i++)
scanf("%d",&a[i]);
mp.clear();
int tmp = 0;
for(int i = 1; i <= n; i++)
{
mp[a[i]] = 1;
while(mp.find(tmp) != mp.end())
tmp++;
mex[i] = tmp;
}
mp.clear();
for(int i = n; i >= 1; i--)
{
if(mp.find(a[i]) == mp.end())
nxt[i] = n + 1;
else
nxt[i] = mp[a[i]];
mp[a[i]] = i;
}
build(1,1,n);
long long sum = 0;
for(int i = 1; i <= n; i++)
{
sum += tree[1].sum;
if(tree[1].mx > a[i])
{
int l = query(1,a[i]);
int r = nxt[i];
if(l < r)
update(1,l,r-1,a[i]);
}
update(1,i,i,0);
}
printf("%lld\n",sum);
}
return 0;
}