题意
给定两个正整数 q , d q, d q,d,定义三元组 ( i , j , k ) (i, j, k) (i,j,k) 满足 i < j < k , k − i ≤ d i < j < k, k - i \le d i<j<k,k−i≤d,为美丽三元组,现在有一个空集和 q q q 组询问,每次给定一个正整数 x x x,若 x x x 不在集合,那么将 x x x 加入集合,若 x x x 在集合中,那么将 x x x 从集合中删除,每次询问计算集合中美丽三元组的个数。
分析:
考虑每个数从集合加入或删除的贡献,对于一个数 x x x,从区间 [ x , x + d ] [x, x + d] [x,x+d] 中选出任意两个不同的数都可以组成美丽三元组(假设 x x x 为三元组中的最小值),记区间中在集合的数量为 c n t cnt cnt,那么方案数为 ( c n t 2 ) \dbinom{cnt}{2} (2cnt),那么考虑区间 [ x − d , x − 1 ] [x - d, x - 1] [x−d,x−1],对区间中的每个数 i i i,考虑 x x x 加入后的影响,设区间 [ i , i + d ] [i, i + d] [i,i+d] 在集合中的个数为 a i a_i ai,那么美丽三元组的个数为 ( a i 2 ) \dbinom{a_i}{2} (2ai),则 x x x 加入后的美丽三元组数量为 ( a i + 1 2 ) \dbinom{a_i + 1}{2} (2ai+1),设整个集合为 S S S,那么在区间 [ x − d , x − 1 ] [x - d, x - 1] [x−d,x−1] 中新增的美丽三元组数量就为 ∑ i = x − d x − 1 ( ( a i + 1 2 ) − ( a i 2 ) ) [ i ∈ S ] = ∑ i = x − d x − 1 a i [ i ∈ S ] \sum\limits_{i = x - d} ^ {x - 1} \left (\dbinom{a_i + 1}{2} - \dbinom{a_i}{2} \right ) [i \in S] = \sum\limits_{i = x - d} ^ {x - 1}a_i [i \in S] i=x−d∑x−1((2ai+1)−(2ai))[i∈S]=i=x−d∑x−1ai[i∈S],对于 x x x 删除后的影响就为 ∑ i = x − d x − 1 ( ( a i 2 ) − ( a i − 1 2 ) ) [ i ∈ S ] = ∑ i = x − d x − 1 ( a i − 1 ) [ i ∈ S ] \sum\limits_{i = x - d} ^ {x - 1} \left (\dbinom{a_i}{2} - \dbinom{a_i - 1}{2} \right ) [i \in S] = \sum\limits_{i = x - d} ^ {x - 1} (a_i - 1) [i \in S] i=x−d∑x−1((2ai)−(2ai−1))[i∈S]=i=x−d∑x−1(ai−1)[i∈S]
考虑使用线段树,我们重点要维护的是每个数 x x x 在区间 [ x , x + d ] [x, x + d] [x,x+d] 中在集合里的个数,那么每次加入或删除操作就相当于对区间 [ x − d , x − 1 ] [x - d, x - 1] [x−d,x−1] 进行区间 + 1 +1 +1 或 − 1 -1 −1 操作,线段树中维护四个值: cnt \text{cnt} cnt 代表区间里在集合中的数的个数, add \text{add} add 代表区间加的懒标记, val \text{val} val 代表每个数 i i i 在区间 [ i , i + d ] [i, i + d] [i,i+d] 中在集合里的个数, sum \text{sum} sum 代表存在集合中的每个数 i i i 在区间 [ i , i + d ] [i, i + d] [i,i+d] 中在集合里的个数。因为每个数是否存在于集合中由 cnt \text{cnt} cnt 是否为 1 1 1 来决定,相当于 val \text{val} val 是全部的值,也就是说无论区间 [ x − d , x − 1 ] [x - d, x - 1] [x−d,x−1] 的某个数存不存在于集合,我们都要维护,那么真正的答案是 sum \text{sum} sum,也就是那些存在于集合里的数的值,通过懒标记用 cnt × val \text{cnt} \times \text{val} cnt×val 来下传,这样就巧妙地算出了一段区间存在于集合中的数对答案的贡献,至于区间 [ x , x + d ] [x, x + d] [x,x+d] 的贡献可以直接查询 val x \text{val}_x valx 的单点值并给答案贡献 ( val x 2 ) \dbinom{\text{val}_x}{2} (2valx)
代码:
#include <bits/stdc++.h>
#define int long long
using namespace std;
constexpr int N = 2e5;
struct SegmentTree {
struct Info {
int l, r, cnt, add, val, sum;
};
vector<Info> tr;
SegmentTree(int n) : tr(n << 2) {
function<void(int, int, int)> build = [&](int u, int l, int r) {
if (l == r) {
tr[u] = {l, r};
} else {
tr[u] = {l, r};
int mid = l + r >> 1;
build(u << 1, l, mid), build(u << 1 | 1, mid + 1, r);
pushup(u);
}
};
build(1, 1, n);
}
void pushdown(int u) {
if (tr[u].add) {
tr[u << 1].add += tr[u].add, tr[u << 1 | 1].add += tr[u].add;
tr[u << 1].val += (tr[u << 1].r - tr[u << 1].l + 1) * tr[u].add;
tr[u << 1 | 1].val += (tr[u << 1 | 1].r - tr[u << 1 | 1].l + 1) * tr[u].add;
tr[u << 1].sum += tr[u << 1].cnt * tr[u].add;
tr[u << 1 | 1].sum += tr[u << 1 | 1].cnt * tr[u].add;
tr[u].add = 0;
}
}
void pushup(int u) {
tr[u].cnt = tr[u << 1].cnt + tr[u << 1 | 1].cnt;
tr[u].val = tr[u << 1].val + tr[u << 1 | 1].val;
tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}
void modifycnt(int u, int pos, int c) {
if (!pos) return ;
if (tr[u].l == tr[u].r) {
tr[u].cnt += c;
if (!tr[u].cnt) {
tr[u].sum = 0;
} else {
tr[u].sum = tr[u].val;
}
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (pos <= mid) {
modifycnt(u << 1, pos, c);
} else {
modifycnt(u << 1 | 1, pos, c);
}
pushup(u);
}
void modifysum(int u, int l, int r, int c) {
if (l > r) return ;
if (tr[u].l >= l && tr[u].r <= r) {
tr[u].val += (tr[u].r - tr[u].l + 1) * c;
tr[u].sum += tr[u].cnt * c;
tr[u].add += c;
return ;
}
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1;
if (l <= mid) modifysum(u << 1, l, r, c);
if (r > mid) modifysum(u << 1 | 1, l, r, c);
pushup(u);
}
int askval(int u, int pos) {
if (!pos) return 0;
if (tr[u].l == tr[u].r) return tr[u].val;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (pos <= mid) {
return askval(u << 1, pos);
} else {
return askval(u << 1 | 1, pos);
}
}
int asksum(int u, int l, int r) {
if (l > r) return 0;
if (tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
pushdown(u);
int mid = tr[u].l + tr[u].r >> 1, res = 0;
if (l <= mid) res += asksum(u << 1, l, r);
if (r > mid) res += asksum(u << 1 | 1, l, r);
return res;
}
};
signed main() {
cin.tie(0) -> sync_with_stdio(0);
int n, d;
cin >> n >> d;
vector<int> st(N + 1);
SegmentTree tr(N + 1);
int ans = 0;
for (int i = 1; i <= n; i ++) {
int x;
cin >> x;
int l = max(1ll, x - d), r = x - 1;
if (!st[x]) {
ans += tr.asksum(1, l, r);
tr.modifysum(1, l, r, 1);
int cnt = tr.askval(1, x);
ans += cnt * (cnt - 1) / 2;
tr.modifycnt(1, x, 1);
} else if (st[x]) {
tr.modifysum(1, l, r, -1);
ans -= tr.asksum(1, l, r);
int cnt = tr.askval(1, x);
ans -= cnt * (cnt - 1) / 2;
tr.modifycnt(1, x, -1);
}
st[x] ^= 1;
cout << ans << "\n";
}
}