题意:有一个序列a[],mex(L, R)表示区间a在区间[L, R]上第一个没出现的最小非负整数,对于序列a[],求所有的mex(L, R)的和(1 <= L <= R <= n,1 <= n <= 200000,0 <= ai <= 10^9)。
题目链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747
——>>线段树就是如此的神~
求出所有的mex(1, i);接着删去第1个结点,就是所有的mex(2, i);接着再删去第1个结点,就是所有的mex(3, i);……最后就是mex(n, n),求和即是答案。
而维护删除结点后的信息,正是线段树的拿手好戏。
对于每个线段树结点(o, L, R),设mexv[o]表示mex(left, R),这里的left表示第一个数的下标,初始1,随着删除的进行,left递增。
设sumv[o]表示区间[L, R]上的所有mexv的和。
当删除了一个结点a[i]时,如果a[i] < mexv[1],说明a[i]被删后一定会使某个区间的mex变成a[i],这个区间就是第一个mexv比a[i]大的i到下个a[i]出现的前一位。
~~
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 200000;
const int maxn = 200000 + 10;
int n, a[maxn], vis[maxn], nxt[maxn], setv[maxn<<2]; //nxt[i]表示下一个a[i]出现的位置,没有为n+1
long long mex1[maxn], mexv[maxn<<2], sumv[maxn<<2]; //mex1[i]表示mex(1, i)
void read(){
for(int i = 1; i <= n; i++){
scanf("%d", &a[i]);
if(a[i] >= N) a[i] = N;
}
}
void init(){
memset(vis, 0, sizeof(vis));
}
void getMex1(){ //获取mex1[]
int ret = 0; //因为mex1[]是递增的,所以ret = 0放在for的外面(放里面就O(n^2)了)
for(int i = 1; i <= n; i++){
vis[a[i]] = 1;
for(; vis[ret]; ret++);
mex1[i] = ret;
}
}
void getNxt(){ //获取nxt[]
for(int i = 0; i <= N; i++) vis[i] = n+1; //初始化为n+1
for(int i = n; i >= 1; i--){ //注意:从右往左!!!
nxt[i] = vis[a[i]];
vis[a[i]] = i;
}
}
void maintain(int o, int L, int R){ //维护函数
int lc = o << 1, rc = lc | 1;
mexv[o] = max(mexv[lc], mexv[rc]);
sumv[o] = sumv[lc] + sumv[rc];
}
void build(int o, int L, int R){ //建树
setv[o] = -1; //赋值标记
if(L == R){
mexv[o] = sumv[o] = mex1[L];
return;
}
int M = (L + R) >> 1;
build(o<<1, L, M);
build(o<<1|1, M+1, R);
maintain(o, L, R);
}
inline void get(int o, int L, int R, int v){ //单点赋值
mexv[o] = v;
sumv[o] = v * (R - L + 1);
setv[o] = v;
}
void pushdown(int o, int L, int R){ //下传机制
if(setv[o] != -1){
int M = (L + R) >> 1;
int lc = o << 1, rc = lc | 1;
get(lc, L, M, setv[o]);
get(rc, M+1, R, setv[o]);
setv[o] = -1;
}
}
int Upper_bound(int o, int L, int R, int v){ //找出第一个mexv比v大的下标
if(L == R) return L;
pushdown(o, L, R);
int M = (L + R) >> 1;
int lc = o << 1, rc = lc | 1;
return mexv[lc] > v ? Upper_bound(lc, L, M, v) : Upper_bound(rc, M+1, R, v);
}
void update(int o, int L, int R, int ql, int qr, int v){ //区间赋值:[ql, qr]赋为v
if(ql <= L && R <= qr){
get(o, L, R, v);
return;
}
pushdown(o, L, R);
int M = (L + R) >> 1;
int lc = o << 1, rc = lc | 1;
if(ql <= M) update(lc, L, M, ql, qr, v);
if(qr > M) update(rc, M+1, R, ql, qr, v);
maintain(o, L, R);
}
void solve(){ //解决函数
long long ret = 0;
for(int i = 1; i <= n; i++){ //枚举起点
ret += sumv[1]; //累加
if(a[i] < mexv[1]){ //这种情况下删了a[i]会使区间[Upper_bound, nxt[i]-1]的mex变成a[i]
int ql = Upper_bound(1, 1, n, a[i]), qr = nxt[i] - 1;
if(ql <= qr) update(1, 1, n, ql, qr, a[i]); //这个判断是必要的,若有数据:2 1 2 0 0,删第一个2时
}
int ql = i, qr = i;
update(1, 1, n, ql, qr, 0); //删除起点产生新起点
if(!mexv[1]) break; //剪枝
}
printf("%I64d\n", ret);
}
int main()
{
while(scanf("%d", &n) == 1 && n){
read();
init();
getMex1();
getNxt();
build(1, 1, n);
solve();
}
return 0;
}