Treap平衡树
题目1:253. 普通平衡树
您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作:
- 插入数值x。
- 删除数值x(若有多个相同的数,应只删除一个)。
- 查询数值x的排名(若有多个相同的数,应输出最小的排名)。
- 查询排名为x的数值。
- 求数值x的前驱(前驱定义为小于x的最大的数)。
- 求数值x的后继(后继定义为大于x的最小的数)。
注意: 数据保证查询的结果一定存在。
输入格式
第一行为n,表示操作的个数。
接下来n行每行有两个数opt和x,opt表示操作的序号(1<=opt<=6)。
输出格式
对于操作3,4,5,6每行输出一个数,表示对应答案。
数据范围
n≤100000n≤100000,所有数均在−107−107到107107内。
输入样例:
8 1 10 1 20 1 30 3 20 4 2 2 10 5 25 6 -1
输出样例:
2 20 20 20
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;//三年竞赛一场空,不开long long见祖宗
//typedef __int128 lll;
#define print(i) cout << "debug: " << i << endsl
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
#define pb(a) push_back(a)
#define x first
#define y second
typedef pair<int, int> pii;
const ll mod = 1e9 + 7;
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
struct node
{
int l, r;
int key, val;
int cnt, size;
}t[maxn];
int n;
int root, idx;
void pushup(int p)
{
t[p].size = t[t[p].l].size + t[t[p].r].size + t[p].cnt;
}
int getnode(int key)
{
t[++idx] = {0, 0, key, rand(), 1, 1};
return idx;
}
void zig(int &p) //right handled-rotation
{
int q = t[p].l;
t[p].l = t[q].r, t[q].r = p, p = q;
pushup(t[p].r), pushup(p);
}
void zag(int &p)
{
int q = t[p].r;
t[p].r = t[q].l, t[q].l = p, p = q;
pushup(t[p].l), pushup(p);
}
void build()
{
root = getnode(-inf), t[root].r = getnode(inf);
pushup(root);
if(t[1].val < t[2].val) zag(root);
}
void insert(int &p, int key)
{
if(!p) p = getnode(key);
else if(key == t[p].key) t[p].cnt++;
else if(key < t[p].key)
{
insert(t[p].l, key);
if(t[t[p].l].val > t[p].val) zig(p);
}
else
{
insert(t[p].r, key);
if(t[t[p].r].val > t[p].val) zag(p);
}
pushup(p);
}
void remove(int &p, int key)
{
if(!p) return;
if(t[p].key == key)
{
if(t[p].cnt > 1) t[p].cnt--;
else if(t[p].l || t[p].r)
{
if(!t[p].r || t[t[p].l].val > t[t[p].r].val)
zig(p), remove(t[p].r, key);
else
zag(p), remove(t[p].l, key);
}
else p = 0;
}
else if(key < t[p].key) remove(t[p].l, key);
else remove(t[p].r, key);
pushup(p);
}
int get_rank_by_key(int p, int key)
{
if(!p) return 0; // 本题中不会发生此情况
if(t[p].key == key) return t[t[p].l].size + 1;
if(key < t[p].key) return get_rank_by_key(t[p].l, key);
if(key > t[p].key) return t[t[p].l].size + t[p].cnt + get_rank_by_key(t[p].r, key);
}
int get_key_by_rank(int p, int rank)
{
if(!p) return inf; // 本题中不会发生此情况
if(t[t[p].l].size >= rank) return get_key_by_rank(t[p].l, rank);
if(t[t[p].l].size + t[p].cnt >= rank) return t[p].key;
return get_key_by_rank(t[p].r, rank - t[t[p].l].size - t[p].cnt);
}
int get_prev(int p, int key)
{
if(!p) return -inf;
if(key <= t[p].key) return get_prev(t[p].l, key);
return max(t[p].key, get_prev(t[p].r, key));
}
int get_next(int p, int key)
{
if(!p) return inf;
if(key >= t[p].key) return get_next(t[p].r, key);
return min(t[p].key, get_next(t[p].l, key));
}
int main()
{
build();
scanf("%d", &n);
while(n--)
{
int opt, x;
scanf("%d%d", &opt, &x);
if(opt == 1) insert(root, x);
else if(opt == 2) remove(root, x);
else if(opt == 3) printf("%d\n", get_rank_by_key(root, x) - 1); //考虑到-inf
else if(opt == 4) printf("%d\n", get_key_by_rank(root, x + 1)); //考虑到-inf
else if(opt == 5) printf("%d\n", get_prev(root, x));
else printf("%d\n", get_next(root, x));
}
}
题目2:265. 营业额统计
代码:
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;//三年竞赛一场空,不开long long见祖宗
//typedef __int128 lll;
#define print(i) cout << "debug: " << i << endsl
#define close() ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define mem(a, b) memset(a, b, sizeof(a))
#define pb(a) push_back(a)
#define x first
#define y second
typedef pair<int, int> pii;
const ll mod = 1e9 + 7;
const int maxn = 1e5 + 10;
const int inf = 0x3f3f3f3f;
struct node
{
int l, r;
int key, val;
}t[maxn];
int root, idx;
int get_node(int key)
{
t[++idx] = {0, 0, key, rand()};
return idx;
}
void zig(int &p)
{
int q = t[p].l;
t[p].l = t[q].r, t[q].r = p, p = q;
}
void zag(int &p)
{
int q = t[p].r;
t[p].r = t[q].l, t[q].l = p, p = q;
}
void insert(int &p, int key)
{
if(!p) p = get_node(key);
else if(t[p].key == key) return;
else if(t[p].key < key)
{
insert(t[p].r, key);
if(t[p].val < t[t[p].r].val) zag(p);
}
else
{
insert(t[p].l, key);
if(t[p].val < t[t[p].l].val) zig(p);
}
}
int get_prev(int p, int key) //找到<=key的最大数
{
if(!p) return -inf;
if(t[p].key > key) return get_prev(t[p].l, key);
return max(t[p].key, get_prev(t[p].r, key));
}
int get_next(int p, int key) //找到>=key的最小数
{
if(!p) return inf;
if(t[p].key < key) return get_next(t[p].r, key);
return min(t[p].key, get_next(t[p].l, key));
}
void build()
{
root = get_node(inf), t[root].r = get_node(-inf);
if(t[1].val < t[2].val) zag(root);
}
int main()
{
build();
int t;
scanf("%d", &t);
ll res = 0;
for(int i = 1; i <= t; i++)
{
int x; scanf("%d", &x);
if(i == 1) res += x;
else res += min(abs(x - get_prev(root, x)), abs(x - get_next(root, x)));
insert(root, x);
}
printf("%lld\n", res);
}