题意
有 n n n 次操作,每次操作有四种操作:
- 加入一个数
- 将所有数同加一个数
- 将所有数同减一个数
- 查询第
k
k
k 大数
当一个数小于 m i n min min 是,就立即删除这个数。
最后输出所有删除的数的个数(操作1删除的数除外)。
做法
前置知识:平衡树
暴力
暴力其实很好想,使用数组维护。
- 加入一个数 k k k:将 k k k 加在数组后边。
- 将所有数同加一个数 k k k:循环数组,每个数加 k k k。
- 将所有数同减一个数 k k k:循环数组,每个数减 k k k。
- 查询第 k k k 大数:输出第 k k k 个数。
删除数字可以在每次操作结束后循环数组,把
<
m
i
n
< min
<min 的数都打上删除标记,在进行别的操作时跳过就可以。
总时间复杂度:
O
(
n
2
)
O(n ^ 2)
O(n2)。
平衡树
题目要求维护一个集合(可重集合),很容易想到用平衡树。
- 加入一个数 k k k:将 k k k 加入平衡树。
- 将所有数同加一个数 k k k:这时就要注意一下了,我们不能将所有数都加上一个数,因此我们为所有数打一个懒标记,每次给懒标记加上一个数,在进行操作时把每个数加上懒标记就是原数了。但这里有一个问题,如果先打标记再加入一个数,加入的数没有被修改,而我们在还原原数是把它加上了懒标记,这会导致某些数比原数更大。这个问题也很好解决,我们再插入的时候插入原值 − - − 懒标记,这样再加上懒标记时就抵消了。
- 将所有数同减一个数 k k k:减 k k k ⟺ \iff ⟺ 加 − k -k −k。
- 查询第 k k k 大数:在平衡树中记录子树大小。从根节点开始。如果当前结点左子树大小 ≥ k \ge k ≥k,则答案为左子树中第 k k k 大数。如果当前节点左子树大小 + 1 = k + 1 = k +1=k,则答案为当前节点的权值。如果当前结点左子树大小 + 1 ≤ k + 1 \le k +1≤k,则答案为右子树中第 k − k - k− 当前结点左子树大小 − 1 - 1 −1大数。
删除数字可以在每次操作后从小到大遍历平衡树,看一下当前数是否 < m i n < min <min,如果是就把当前数删除,如果不是就立即返回。(因为删除操作的循环次数与一共删除的数的个数是同一级别的,最多删除 n n n 个数,因此最多循环 n n n 级别次,删除的总时间复杂度为 O ( n log ( n ) O(n\log(n) O(nlog(n))。
总时间复杂度 O ( n log ( n ) O(n\log(n) O(nlog(n)。
平衡树的实现方法有很多,我使用 FHQ treap
实现。
代码
#include <iostream>
#include <cstring>
#include <algorithm>
using namespace std;
const int N = 300010;
int n, m;
struct Node
{
int l, r;
int key, val;
int s;
}tr[N];
int idx, root;
int build(int val)
{
tr[ ++ idx].key = rand();
tr[idx].val = val;
tr[idx].s = 1;
return idx;
}
void pushup(int u)
{
tr[u].s = tr[tr[u].l].s + tr[tr[u].r].s + 1;
}
void split(int u, int val, int& x, int& y)
{
if (!u) x = y = 0;
else
{
if (tr[u].val <= val)
{
x = u;
split(tr[u].r, val, tr[u].r, y);
}
else
{
y = u;
split(tr[u].l, val, x, tr[u].l);
}
pushup(u);
}
}
int merge(int x, int y)
{
if (!x || !y) return x + y;
if (tr[x].key < tr[y].key)
{
tr[x].r = merge(tr[x].r, y);
pushup(x);
return x;
}
else
{
tr[y].l = merge(x, tr[y].l);
pushup(y);
return y;
}
}
void insert(int val)
{
int x, y;
split(root, val, x, y);
root = merge(merge(x, build(val)), y);
}
void erase(int val)
{
int x, y, z;
split(root, val, x, z);
split(x, val - 1, x, y);
y = merge(tr[y].l, tr[y].r);
root = merge(merge(x, y), z);
}
int get_val(int rank)
{
int u = root;
while (true)
{
if (tr[tr[u].l].s >= rank) u = tr[u].l;
else if (tr[tr[u].l].s + 1 == rank) return tr[u].val;
else rank -= tr[tr[u].l].s + 1, u = tr[u].r;
}
}
int get_min()
{
int u = root;
while (tr[u].l) u = tr[u].l;
return tr[u].val;
}
int main()
{
scanf("%d%d", &n, &m);
int x = 0, res = 0;
while (n -- )
{
char s[3];
int k;
scanf("%s%d", s, &k);
if (*s == 'I')
{
if (k >= m) insert(k - x);
}
else if (*s == 'A') x += k;
else if (*s == 'S') x -= k;
else
{
if (tr[root].s < k) puts("-1");
else printf("%d\n", get_val(tr[root].s - k + 1) + x);
}
while (tr[root].s)
{
int v = get_min();
if (v + x < m) res ++, erase(v);
else break;
}
}
printf("%d\n", res);
return 0;
}