题目:
给出一个序列,mex{}表示集合中没有出现的最小的自然数。
然后求 ∑mex(i,j) 。
解析:
思路转载自 cxlove
考虑左端点固定时的所有区间的mex值,这个序列是一个非递减序列,这点首先要明白。
初始时,先求出mex[j]表示mex(1, j)。(可以用map求出)
对于每一个左端点i,就是一个区间求和。(可以利用线段树维护)现在需要考虑的是左端点的改变对于序列的影响。
即左端点i,从 i -> i + 1,mex[j]的改变……,即删去 ai 对于序列的影响。
如果 a[j]=a[i]且j>i,不存在a[k]=a[i](j>k>i) ,那么 j 即 a[i] 下一次出现的位置。(也可利用map,求出 j 的位置)根据mex的定义,我们知道
mex[k](k>=j) 不会改变,因为删掉的 ai 还是存在于序列当中,所以不受影响。之后需要考虑的是 i+1 到 j−1 这段区间的mex{}值。
删去了 ai 之后,使得原先mex{}值大于 ai 的,都会更新成 ai 。
很好理解。因为是没有出现的最小的,然而 ai 更小。之前说过这是一个非递减的序列,所以原先mex值大于 ai 的也是一段连续的区间,所以我们可以找到最靠左的位置r,使得 a[i] < mex[r]。(二分查找最靠左的位置)
那么 r 到 j-1 这段区间的mex值,便会更新为a[i]。所以全部搞定。用线段树维护一下mex序列,区间更新,区间求和,然后一个查找就可以了。
my code
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <map>
#include <set>
#define ls (o<<1)
#define rs (o<<1|1)
#define lson ls, L, M
#define rson rs, M+1, R
#define MID (L + R) >> 1
#define LEN(L, R) ((R) - (L) + 1)
using namespace std;
typedef long long ll;
const int INF = 0x3f3f3f3f;
const int N = 200005;
ll a[N];
ll sumv[N<<2], cov[N<<2];
int n, mex[N], jump[N];
map<ll, int> mp;
inline void pushDown(int o, int L, int R) {
if(cov[o] != -1) {
int M = MID;
cov[ls] = cov[rs] = cov[o];
sumv[ls] = LEN(L, M) * cov[o];
sumv[rs] = LEN(M+1, R) * cov[o];
cov[o] = -1;
}
}
inline void pushUp(int o) {
sumv[o] = sumv[ls] + sumv[rs];
}
void build(int o, int L, int R) {
cov[o] = -1;
sumv[o] = 0;
if(L == R) {
cov[o] = sumv[o] = mex[L];
return ;
}
int M = MID;
build(lson);
build(rson);
pushUp(o);
}
void modify(int o, int L, int R, int ql, int qr, ll val) {
if(ql <= L && R <= qr) {
cov[o] = val;
sumv[o] = LEN(L, R) * val;
return ;
}
pushDown(o, L, R);
int M = MID;
if(ql <= M) modify(lson, ql, qr, val);
if(qr > M) modify(rson, ql, qr, val);
pushUp(o);
}
ll query(int o, int L, int R, int ql, int qr) {
if(ql <= L && R <= qr) return sumv[o];
pushDown(o, L, R);
int M = MID;
ll ret = 0;
if(ql <= M) ret += query(lson, ql, qr);
if(qr > M) ret += query(rson, ql, qr);
return ret;
}
ll get(int o, int L, int R, int pos) {
if(L == R) return sumv[o];
pushDown(o, L, R);
int M = MID;
if(pos <= M) return get(lson, pos);
else return get(rson, pos);
}
void getMex() {
mp.clear();
int tmp = 0;
for(int i = 1; i <= n; i++) {
mp[a[i]] = 1;
while(mp.find(tmp) != mp.end())
tmp++;
mex[i] = tmp;
}
mp.clear();
for(int i = n; i >= 1; i--) {
if(mp.find(a[i]) == mp.end())
jump[i] = n+1;
else jump[i] = mp[a[i]];
mp[a[i]] = i;
}
}
int search(int start, int end, int lim) {
int L = start, R = end+1;
while(L < R) {
int M = MID;
ll tmp = get(1, 1, n, M);
if(tmp > lim) R = M;
else L = M + 1;
}
return L;
}
ll cal() {
int ql, qr;
ll ret = query(1, 1, n, 1, n);
for(int i = 2; i <= n; i++) {
qr = jump[i-1] - 1;
ql = search(i, qr, a[i-1]);
if(ql <= qr)
modify(1, 1, n, ql, qr, a[i-1]);
ret += query(1, 1, n, i, n);
}
return ret;
}
int main() {
while(~scanf("%d", &n) && n) {
for(int i = 1; i <= n; i++) {
scanf("%lld", &a[i]);
}
getMex();
build(1, 1, n);
printf("%lld\n", cal());
}
return 0;
}