题目:
http://www.spoj.com/problems/ORDERSET/en/
题意:
有下面四种操作:
- I x 往集合中插入x,若存在则不操作
- D x 从集合中删除x,若不存在则不操作
- K x 求集合中第x大的数,若x大于集合的大小输出invalid
- C x 统计集合中小于x的数的个数
思路:
简单的平衡树题目,用treap很容易实现,另外学了一波pb_ds试着用了一下
treap:
#include <bits/stdc++.h>
using namespace std;
const int N = 200000 + 10, INF = 0x3f3f3f3f;
struct node
{
int val, pri, sz, son[2];
void init(int _val, int _pri, int _sz)
{
val = _val, pri = _pri, sz = _sz;
son[0] = son[1] = 0;
}
}tr[N];
int treap_root, treap_tot;
void treap_init()
{
treap_root = treap_tot = 0;
tr[0].init(0, 0, 0);
}
void treap_update(int x)
{
tr[x].sz = tr[tr[x].son[0]].sz + tr[tr[x].son[1]].sz + 1;
}
void treap_rotate(int &x, int p)
{
int y = tr[x].son[!p];
tr[x].son[!p] = tr[y].son[p];
tr[y].son[p] = x;
treap_update(x); treap_update(y);
x = y;
}
void treap_insert(int &x, int val)
{
if(! x) tr[x = ++treap_tot].init(val, rand(), 1);
else
{
tr[x].sz++;
int p = val > tr[x].val;
treap_insert(tr[x].son[p], val);
if(tr[x].pri < tr[tr[x].son[p]].pri) treap_rotate(x, !p);
}
}
bool treap_find(int x, int val)
{
if(! x) return false;
if(tr[x].val == val) return true;
int p = val > tr[x].val;
return treap_find(tr[x].son[p], val);
}
void treap_del(int &x, int val)
{
if(tr[x].val == val)
{
if(tr[x].son[0] && tr[x].son[1])
{
int p = tr[tr[x].son[0]].pri > tr[tr[x].son[1]].pri;
treap_rotate(x, p);
treap_del(x, val);
}
else x = tr[x].son[0] + tr[x].son[1];
}
else
{
tr[x].sz--;
int p = val > tr[x].val;
treap_del(tr[x].son[p], val);
}
}
int treap_kth(int x, int k)
{
if(k == tr[tr[x].son[0]].sz + 1) return tr[x].val;
else if(k > tr[tr[x].son[0]].sz + 1) return treap_kth(tr[x].son[1], k - tr[tr[x].son[0]].sz - 1);
else return treap_kth(tr[x].son[0], k);
}
int treap_rank(int x, int val)
{
if(! x) return 0;
if(val == tr[x].val) return tr[tr[x].son[0]].sz;
else if(val > tr[x].val) return tr[tr[x].son[0]].sz + 1 + treap_rank(tr[x].son[1], val);
else return treap_rank(tr[x].son[0], val);
}
int main()
{
int t;
while(~ scanf("%d", &t))
{
treap_init();
char ch;
int val, num = 0;
for(int i = 1; i <= t; i++)
{
scanf(" %c%d", &ch, &val);
if(ch == 'I')
{
if(! treap_find(treap_root, val)) treap_insert(treap_root, val), num++;
}
else if(ch == 'D')
{
if(treap_find(treap_root, val)) treap_del(treap_root, val), num--;
}
else if(ch == 'K')
{
if(num < val) printf("invalid\n");
else printf("%d\n", treap_kth(treap_root, val));
}
else printf("%d\n", treap_rank(treap_root, val));
}
}
return 0;
}
pb_ds:
#include <bits/stdc++.h>
#include <ext/pb_ds/assoc_container.hpp>
#include <ext/pb_ds/tree_policy.hpp>
using namespace std;
using namespace __gnu_pbds;
tree<int, null_type, less<int>, rb_tree_tag, tree_order_statistics_node_update> bst;
//用splay和ov_tree会TLE
//tree<int, null_type, less<int>, splay_tree_tag, tree_order_statistics_node_update> bst;
//tree<int, null_type, less<int>, ov_tree_tag, tree_order_statistics_node_update> bst;
int main()
{
int t, val;
char ch;
scanf("%d", &t);
for(int i = 1; i <= t; i++)
{
scanf(" %c%d", &ch, &val);
if(ch == 'I') bst.insert(val);
else if(ch == 'D') bst.erase(val);
else if(ch == 'K')
{
if(bst.size() >= val) printf("%d\n", *bst.find_by_order(val-1));//返回的是迭代器
else printf("invalid\n");
}
else printf("%d\n", bst.order_of_key(val));//统计比val小的值的个数,是严格小于
}
return 0;
}
用树状数组或者线段树离线也可以做,核心思想就是维护前缀和。首先把数据离散化,对于插入操作,先检查树状数组里面有没有这个值,如果没有则往树状数组里面直接插入,值为1。对于删除操作,往树状数组里面插入,值为-1。对于第k大值,可以二分枚举答案,用枚举值的前缀和来判断当前枚举值是不是第k大。对于统计小于x的值的个数,直接求x-1的前缀和就好了。跑的比treap都快了一点。。。。
#include <bits/stdc++.h>
using namespace std;
const int N = 200000 + 10, INF = 0x3f3f3f3f;
bool vis[N];
struct BIT
{
int n, b[N];
void init(int _n)
{
n = _n;
memset(b, 0, sizeof b);
}
void add(int i, int x)
{
for(; i <= n; i += i & -i) b[i] += x;
}
int sum(int i)
{
int ans = 0;
for(; i >= 1; i -= i & -i) ans += b[i];
return ans;
}
}bit;
char op[N];
int a[N], b[N];
int main()
{
int t;
scanf("%d", &t);
int k = 0;
for(int i = 1; i <= t; i++)
{
scanf(" %c%d", &op[i], &a[i]);
b[++k] = a[i];
}
sort(b + 1, b + 1 + k);
k = unique(b + 1, b + 1 + k) - b - 1;
bit.init(k);
memset(vis, 0, sizeof vis);
int num = 0;
for(int i = 1; i <= t; i++)
{
int tmp = lower_bound(b + 1, b + 1 + k, a[i]) - b;
if(op[i] == 'I')
{
if(! vis[tmp]) bit.add(tmp, 1), num++, vis[tmp] = true;
}
else if(op[i] == 'D')
{
if(vis[tmp]) bit.add(tmp, -1), num--, vis[tmp] = false;
}
else if(op[i] == 'K')
{
if(num < a[i]) puts("invalid");
else
{
int l = 1, r = k, res = 0;
while(l <= r)
{
int mid = (l + r) >> 1;
if(bit.sum(mid) >= a[i]) res = b[mid], r = mid - 1;
else l = mid + 1;
}
printf("%d\n", res);
}
}
else printf("%d\n", bit.sum(tmp - 1));
}
return 0;
}