令 sec(x) 表示包含 0 ~ x-1 所有数的最小区间(即 mex 至少为 x 的最小区间)。假设当前已插入到 t,只需维护 x = 1 ~ t 的 sec(x) 的左右端点位置即可(这里端点位置定义为在当前已经存在的数中的位置;sec(0) 不产生贡献,不用管)。
首先考虑插入操作。假设当前插入 t,修改 sec(t)、sec(t + 1) 是平凡的;对于 x = 1 ~ t - 1 的 sec(x),如果其左端点位置不小于 t 插入后所在的位置,则 sec(x) 的左端点要 +1。又因为 sec(x) 的左端点关于 x 单调不增,需要 +1 的恰好是一段前缀,因此可以用线段树维护;右端点同理(右端点关于 x 单调不减,需要 +1 的是一段后缀)。
再考虑插入 t 后,询问当前所有连续子序列的 mex 之和。可以发现,一个 mex 为 w 的子区间的贡献可以拆分到 sec(1)、sec(2)、…、sec(w) 上(每个各贡献 1);等价地,设 sec(x) 对应区间 [l, r],则 sec(x) 的贡献为包含它的子区间个数:l * (t - r + 1)(这里 t 为当前已有的数的个数)。
这题卡空间时间...
#include<bits/stdc++.h>
// Competitive-programming shorthands. Most of them (pii, fi, sc, pb, trav, all,
// VI, VLL, pll, double) are not used in the visible part of this file.
#define pii pair<int,int>
#define fi first
#define sc second
#define pb push_back
#define ll long long
// trav(v, x): range-for over x, taking each element by value.
#define trav(v,x) for(auto v:x)
#define all(x) (x).begin(), (x).end()
#define VI vector<int>
#define VLL vector<ll>
#define pll pair<ll, ll>
// NOTE: from here on every `double` token expands to `long double`.
#define double long double
// NOTE: from here on every `int` token expands to `unsigned int`
// (presumably why main is declared with `signed` instead of `int`).
#define int unsigned int
using namespace std;
// Capacity of the segment-tree array seg[]; the arrays a, tr and val hold N / 2 entries.
const int N = 2097259;
// Sentinel returned by find_r() when no leaf qualifies.
const int inf = 1e9;
// Disabled 64-bit alternative for inf.
//const ll inf = 1e18
// Not referenced in the visible part of this file; the trailing note records 1e9 + 7 as an alternative value.
const ll mod = 998244353;//1e9 + 7
// Debug tracing helpers; the printing versions are compiled only when LOCAL is defined.
#ifdef LOCAL
// Recursion terminator: ends the trace line on stderr (endl also flushes).
void debug_out(){cerr << endl;}
// Prints H, converted with to_string, preceded by a space, then recurses on the remaining arguments.
template<typename Head, typename... Tail>
void debug_out(Head H, Tail... T)
{
cerr << " " << to_string(H);
debug_out(T...);
}
// debug(a, b) writes "[a, b]:" (the argument text as written) followed by the argument values.
#define debug(...) cerr << "[" << #__VA_ARGS__ << "]:", debug_out(__VA_ARGS__)
#else
// Without LOCAL, debug(...) expands to the constant 42, so a `debug(x);` statement has no effect
// and its arguments are never evaluated.
#define debug(...) 42
#endif
// Fast reader: parses the next run of decimal digits on stdin into x.
//
// Any non-digit characters before the number are skipped (a '-' sign is
// therefore ignored). If the input ends before a digit is found, x is left
// at 0 and the function returns instead of looping forever.
inline void read(int &x){
    // `signed` rather than `int`: this file redefines `int` as `unsigned int`,
    // and getchar()'s result has to be kept in its full signed width so that
    // EOF can be told apart from real characters.
    signed ch = getchar();
    x = 0;
    // Skip separators, but stop at end of input.
    while(ch != EOF && (ch < '0' || ch > '9'))
        ch = getchar();
    // Accumulate the digits; EOF is not a digit, so it also ends this loop.
    while('0' <= ch && ch <= '9'){
        x = x * 10 + (ch - '0');
        ch = getchar();
    }
}
// One segment-tree node. Index 0 of v/dt together with mxl belongs to the
// "left endpoint" quantity, index 1 together with mxr to the "right endpoint"
// quantity (see upd(), which seeds a leaf with 1 and `to` respectively).
struct Node{
    ll v[2];   // sums of the two tracked quantities over the node's range
    int mxl;   // maximum of the index-0 quantity over the node's range
    int mxr;   // maximum of the index-1 quantity over the node's range
    int dt[2]; // pending per-element additions not yet handed to the children
};
// Tree storage, addressed heap-style: children of k are 2k and 2k + 1.
// Being a global, it starts out zero-filled.
Node seg[N];
// Number of input values; read in sol() and used as the default right bound of the tree functions.
int n;
// a[1..n]: the input values in the order they were read.
int a[N / 2];
// Fenwick (binary indexed) tree of counts over indices 1..n; updated by add(), queried by ask().
int tr[N / 2];
// Fenwick point update: adds one to the count stored at 1-based index x.
// x must be positive, otherwise the step below never advances.
void add(ll x)
{
    while(x <= n)
    {
        ++tr[x];
        x += x & (-x); // move to the next node whose range covers x
    }
}
// Fenwick prefix query: returns the total count stored at indices 1..x.
int ask(ll x)
{
    ll total = 0;
    while(x != 0)
    {
        total += tr[x];
        x -= x & (-x); // drop the lowest set bit
    }
    return total;
}
// val[x]: value of ask(x + 1) taken in sol() just before x itself is added, i.e. the number of
// earlier input elements that are <= x.
int val[N / 2];
// Segment-tree helpers. They expand in terms of the local names l, r and k, so they are only
// usable inside functions that have those variables in scope.
// Midpoint of the current range [l, r].
#define mid ((l + r) >> 1)
// Index of the left child of node k.
#define ls (k << 1)
// Index of the right child of node k.
#define rs (k << 1 | 1)
void push_up(int k)
{
for(int i = 0; i < 2; i++)
seg[k].v[i] = seg[ls].v[i] + seg[rs].v[i];
seg[k].mxl = max(seg[ls].mxl, seg[rs].mxl);
seg[k].mxr = max(seg[ls].mxr, seg[rs].mxr);
}
// Applies a pending "add to every element" pair to node x.
//   dt  - dt[0] is the addition for the index-0 quantity, dt[1] for the index-1 quantity
//   len - number of elements covered by x, used to scale the sums
// The same amounts are also accumulated into x's own pending tags.
void gao(Node &x, int *dt, ll len)
{
    const int addL = dt[0];
    if(addL != 0)
    {
        x.mxl += addL;
        x.dt[0] += addL;
        x.v[0] += addL * len;
    }
    const int addR = dt[1];
    if(addR != 0)
    {
        x.mxr += addR;
        x.dt[1] += addR;
        x.v[1] += addR * len;
    }
}
// Hands node k's pending additions to both children and clears them on k.
// [l, r] is the range covered by k; it determines each child's element count.
void push_down(int k, int l, int r)
{
    auto &cur = seg[k];
    if(cur.dt[0] == 0 && cur.dt[1] == 0)
        return; // nothing pending
    const int m = (l + r) >> 1;
    gao(seg[k << 1], cur.dt, m - l + 1); // left child covers [l, m]
    gao(seg[k << 1 | 1], cur.dt, r - m); // right child covers [m + 1, r]
    cur.dt[0] = 0;
    cur.dt[1] = 0;
}
// Initialises leaf `to`: its index-0 quantity becomes 1 and its index-1 quantity becomes `to`.
// Pending tags are pushed down along the way and every node on the path is
// recomputed afterwards. (l, r, k) describe the subtree to start from.
void upd(int to, int l = 1, int r = n, int k = 1)
{
    const int top = k;
    // Walk down to the leaf, flushing pending additions on each level.
    while(l != r)
    {
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        if(to <= m)
        {
            r = m;
            k = k << 1;
        }
        else
        {
            l = m + 1;
            k = k << 1 | 1;
        }
    }
    auto &leaf = seg[k];
    leaf.mxl = 1;
    leaf.v[0] = 1;
    leaf.mxr = to;
    leaf.v[1] = to;
    // Walk back up to the starting node, refreshing the aggregates bottom-up.
    while(k != top)
    {
        k >>= 1;
        push_up(k);
    }
}
// Returns the rightmost leaf index whose mxl is at least `val`, or 0 when no
// leaf qualifies. The right child is preferred whenever its maximum reaches
// `val`. (l, r, k) describe the subtree to search.
int find_l(int val, int l = 1, int r = n, int k = 1)
{
    while(l != r)
    {
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        const int right = k << 1 | 1;
        if(seg[right].mxl >= val)
        {
            l = m + 1;
            k = right;
        }
        else
        {
            r = m;
            k = k << 1;
        }
    }
    return seg[k].mxl < val ? 0 : l;
}
// Returns the leftmost leaf index whose mxr is at least `val`, or inf when no
// leaf qualifies. The left child is preferred whenever its maximum reaches
// `val`. (l, r, k) describe the subtree to search.
int find_r(int val, int l = 1, int r = n, int k = 1)
{
    while(l != r)
    {
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        const int left = k << 1;
        if(seg[left].mxr >= val)
        {
            r = m;
            k = left;
        }
        else
        {
            l = m + 1;
            k = left | 1;
        }
    }
    return seg[k].mxr < val ? inf : l;
}
// Range increment: adds 1 to one tracked quantity of every leaf in [L, R].
//   op == 0 -> the index-0 quantity (v[0], mxl, dt[0])
//   otherwise -> the index-1 quantity (v[1], mxr, dt[1])
// (l, r, k) describe the subtree currently being visited.
void cg(int L, int R, int op, int l = 1, int r = n, int k = 1)
{
    if(R < l || r < L)
        return; // no overlap with this subtree

    if(l < L || R < r)
    {
        // Partial overlap: flush pending tags, recurse, then rebuild this node.
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        cg(L, R, op, l, m, k << 1);
        cg(L, R, op, m + 1, r, k << 1 | 1);
        push_up(k);
        return;
    }

    // Fully covered: update the aggregates here and leave a tag for the children.
    const int side = (op == 0) ? 0 : 1;
    auto &cur = seg[k];
    cur.v[side] += r - l + 1;
    cur.dt[side]++;
    if(side == 0)
        cur.mxl++;
    else
        cur.mxr++;
}
// Prefix sum of the index-1 quantity: total of v[1] over the leaves from the
// start of the visited subtree up to and including `to`.
ll calc1(int to, int l = 1, int r = n, int k = 1)
{
    ll acc = 0;
    while(l != r)
    {
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        const int left = k << 1;
        if(to <= m)
        {
            r = m;
            k = left;
        }
        else
        {
            // The whole left half lies inside the prefix.
            acc += seg[left].v[1];
            l = m + 1;
            k = left | 1;
        }
    }
    return acc + seg[k].v[1];
}
// Suffix sum of the index-0 quantity: total of v[0] over the leaves from `to`
// through the end of the visited subtree.
ll calc0(int to, int l = 1, int r = n, int k = 1)
{
    ll acc = 0;
    while(l != r)
    {
        push_down(k, l, r);
        const int m = (l + r) >> 1;
        const int right = k << 1 | 1;
        if(to > m)
        {
            l = m + 1;
            k = right;
        }
        else
        {
            // The whole right half lies inside the suffix.
            acc += seg[right].v[0];
            r = m;
            k = k << 1;
        }
    }
    return acc + seg[k].v[0];
}
void sol()
{
// cerr << sizeof seg << '\n';
read(n);
for(int i = 1; i <= n; i++)
read(a[i]);
// n = 1e6;
// for(int i = 1; i <= n; i++)
// a[i] = i - 1;//cin >> a[i];
// random_shuffle(a + 1, a +n + 1);
for(int i = 1; i <= n; i++)
{
ll x = a[i];
ll y = ask(x + 1);
val[x] = y;
add(x + 1);
}
// cerr << "GG" << '\n';
ll ans = 0;
ll res = 0;
for(int i = 0; i < n; i++)
{
int s = val[i];
// cerr <<"!!" << i << ' ' << s << '\n';
if(i > 0)
{
int ps;
ps = find_l(s + 1);
// cerr << ps << '\n';
if(ps >= 1)
res += calc1(ps), cg(1, ps, 0);
ps = find_r(s + 1);
//cerr << ps << '\n';
if(ps <= i)
res += calc0(ps), cg(ps, i, 1);
}
upd(i + 1);
res += i + 1;
ans += seg[1].v[0] * (i + 2) - res;
// cerr << i << ' ' << res << '\n';
}
cout << ans << '\n';
}
// Entry point: the input holds a single test case, so sol() runs exactly once.
// Declared `signed` because `int` is redefined as `unsigned int` in this file.
signed main()
{
    sol();
    return 0;
}