题意
给定一个序列,求出其中有多少个区间\([L,R]\),满足在对\([L,R]\)中所有元素排序后,其中相邻元素的差的绝对值不大于\(1\)
解法
我们先考虑该序列是一个排列的情况,也就是每个元素出现且仅出现一次
我们发现,满足条件的区间\([L,R]\)有以下性质:
令\(A\)为\(a_L\)至\(a_R\)中的最大值,\(B\)为最小值
那么\(A-B=R-L\)
并且我们还能发现,对于任意一个区间\(A-B\geq R-L\),可以根据抽屉原理证明,事实上这个结论也是挺显然的
我们考虑枚举右端点,统计左端点的贡献
对于每个\(l\),在线段树的\(l\)位置中插入\(A-B+l\),维护最小值和该最小值出现的次数,每次只要查询根节点即可
每次右端点向右移一位,新添加的\(a[r]\)可能作为之前区间的最大值和最小值,我们发现这个是可以用单调栈维护的,在弹栈的过程中线段树区间修改即可
接下来我们处理序列非排列的情况
现在我们发现,符合条件的区间应该满足以下条件\(A-B=R-L-cnt\),其中\(cnt\)是不是第一次出现在该区间内的数的个数
维护一下就行了
代码
#include <map>
#include <cstdio>
using namespace std;
const int N = 1e6 + 10;
const int oo = 0x3f3f3f3f;
int n;
int a[N], lst[N];
int topmx, topmn;
int stkmx[N], stkmn[N];
map<int, int> col;
template<typename _T> void read(_T &x) {
int c = getchar(); x = 0;
while (c < '0' || c > '9') c = getchar();
while (c >= '0' && c <= '9') x = x * 10 + c - 48, c = getchar();
}
inline int min(int x, int y) { return x < y ? x : y; }
inline int max(int x, int y) { return x > y ? x : y; }
struct SegTree {
#define ls x << 1
#define rs x << 1 | 1
struct node {
int val, sum, tag;
node() { val = oo, sum = tag = 0; }
} t[N << 2];
void updtag(int x, int v) { t[x].val += v, t[x].tag += v; }
void pushdown(int x) {
if (t[x].tag) {
updtag(ls, t[x].tag);
updtag(rs, t[x].tag);
t[x].tag = 0;
}
}
void pushup(int x) {
t[x].val = min(t[ls].val, t[rs].val);
t[x].sum = 0;
if (t[x].val == t[ls].val) t[x].sum += t[ls].sum;
if (t[x].val == t[rs].val) t[x].sum += t[rs].sum;
}
void modify(int x, int l, int r, int ql, int qr, int v) {
if (ql <= l && r <= qr)
return updtag(x, v), void();
int mid = l + r >> 1;
pushdown(x);
if (ql <= mid)
modify(ls, l, mid, ql, qr, v);
if (qr > mid)
modify(rs, mid + 1, r, ql, qr, v);
pushup(x);
}
void change(int x, int l, int r, int k, int v) {
if (l == r)
return t[x].val = 0, t[x].sum = 1, void();
int mid = l + r >> 1;
pushdown(x);
if (k <= mid)
change(ls, l, mid, k, v);
else
change(rs, mid + 1, r, k, v);
pushup(x);
}
#undef ls
#undef rs
} tr;
long long solve() {
long long res = 0;
for (int i = 1; i <= n; ++i) {
while (topmx && a[i] >= a[stkmx[topmx]]) {
tr.modify(1, 1, n, stkmx[topmx - 1] + 1, stkmx[topmx], -a[stkmx[topmx]]);
--topmx;
}
while (topmn && a[i] <= a[stkmn[topmn]]) {
tr.modify(1, 1, n, stkmn[topmn - 1] + 1, stkmn[topmn], a[stkmn[topmn]]);
--topmn;
}
stkmx[++topmx] = stkmn[++topmn] = i;
tr.modify(1, 1, n, stkmx[topmx - 1] + 1, i, a[i]);
tr.modify(1, 1, n, stkmn[topmn - 1] + 1, i, -a[i]);
tr.modify(1, 1, n, lst[i] + 1, i, -1);
tr.change(1, 1, n, i, 0);
if (!tr.t[1].val)
res += tr.t[1].sum;
}
return res;
}
int main() {
read(n);
for (int i = 1; i <= n; ++i) read(a[i]);
for (int i = 1; i <= n; ++i) {
lst[i] = col[a[i]];
col[a[i]] = i;
}
printf("%lld\n", solve());
return 0;
}