模板:求第k小的key值,以及求给定key是第几小(即它的排名k)
// Treap node for template 1: BST on `key`, max-heap on the random `rank`.
struct Node {
    int size;     // number of nodes in the subtree rooted here
    int rank;     // random heap priority; insert() keeps larger ranks nearer the root
    int key;      // stored value (BST ordering key)
    Node* lson;   // subtree with keys <= key
    Node* rson;   // subtree with keys > key

    // A fresh leaf holding x.
    Node(int x) : size(1), rank(rand()), key(x), lson(NULL), rson(NULL) {}
};
// Size of the subtree rooted at o; an empty (NULL) subtree has size 0.
int getSize(Node* o) {
    return o != NULL ? o->size : 0;
}
// Left rotation: the right child of o becomes the new subtree root.
// The new root inherits the old subtree size; the demoted node's size is recomputed.
void L_rotate(Node*& o) {
    Node* old_root = o;
    Node* pivot = old_root->rson;          // rises to the top
    old_root->rson = pivot->lson;
    pivot->lson = old_root;
    pivot->size = old_root->size;          // pivot now spans the whole subtree
    old_root->size = 1 + getSize(old_root->lson) + getSize(old_root->rson);
    o = pivot;
}
// Right rotation: the left child of o becomes the new subtree root.
// The new root inherits the old subtree size; the demoted node's size is recomputed.
void R_rotate(Node*& o) {
    Node* old_root = o;
    Node* pivot = old_root->lson;          // rises to the top
    old_root->lson = pivot->rson;
    pivot->rson = old_root;
    pivot->size = old_root->size;          // pivot now spans the whole subtree
    old_root->size = 1 + getSize(old_root->lson) + getSize(old_root->rson);
    o = pivot;
}
// Insert x into the treap rooted at o. Keys greater than a node's key go right,
// all others (including equal keys) go left. After the recursive insert, the
// child is rotated up if its rank beats the parent's (max-heap on rank).
void insert(Node*& o, int x) {
    if (o == NULL) {
        o = new Node(x);
        return;
    }
    ++o->size;                      // x will end up somewhere below o
    if (x > o->key) {
        insert(o->rson, x);
        if (o->rson->rank > o->rank) L_rotate(o);
        return;
    }
    insert(o->lson, x);
    if (o->lson->rank > o->rank) R_rotate(o);
}
// Return the k-th smallest key (1-based) in the treap rooted at o,
// or -1 when there is no such element.
int kth(Node* o, int k) {
    while (o != NULL) {
        int leftCount = getSize(o->lson);
        if (k == leftCount + 1) return o->key;
        if (k <= leftCount) {
            o = o->lson;
        } else {
            k -= leftCount + 1;     // skip the left subtree and this node
            o = o->rson;
        }
    }
    return -1;
}
// Return the 1-based ascending rank of `key` in the treap rooted at o
// (the rank of the first matching node met on the search path), or -1 if absent.
int find(Node* o, int key) {
    int smaller = 0;                // nodes already known to precede the match
    for (Node* cur = o; cur != NULL; ) {
        int leftAndSelf = (cur->lson != NULL ? cur->lson->size : 0) + 1;
        if (cur->key == key) return smaller + leftAndSelf;
        if (cur->key > key) {
            cur = cur->lson;
        } else {
            smaller += leftAndSelf;
            cur = cur->rson;
        }
    }
    return -1;
}
// Delete the node that `o` points to (the root of this subtree).
// While the node has two children, the child with the larger rank is rotated
// up (preserving the max-heap on rank) and deletion continues one level down.
// Once it has at most one child, that child takes its place.
// Subtree sizes inside this subtree are kept correct. The caller is responsible
// for adjusting the sizes of any ancestors above `o`.
void remove(Node*& o) {
    if (o == NULL) return;
    if (o->lson == NULL || o->rson == NULL) {
        Node* doomed = o;
        o = (o->lson == NULL) ? o->rson : o->lson;
        delete doomed;
        return;
    }
    if (o->lson->rank > o->rson->rank) {
        R_rotate(o);                // the node to delete is now o->rson
        remove(o->rson);
    } else {
        L_rotate(o);                // the node to delete is now o->lson
        remove(o->lson);
    }
    // One node disappeared below o; refresh o's subtree size so that kth()/find()
    // stay correct after a removal.
    o->size = 1 + (o->lson != NULL ? o->lson->size : 0)
                + (o->rson != NULL ? o->rson->size : 0);
}
// Move every node of treap `b` into treap `a`, freeing b's nodes (post-order,
// so children are handled before their parent is deleted).
// Both pointers are taken by reference:
//  - `a`: rotations inside insert() may replace the root; by value, that new
//    root (or a freshly allocated one when `a` was NULL) never reached the caller;
//  - `b`: the caller's pointer ends up NULL instead of dangling.
// A NULL `b` is a no-op.
void merge(Node*& a, Node*& b) {
    if (b == NULL) return;
    merge(a, b->lson);
    merge(a, b->rson);
    insert(a, b->key);
    delete b;
    b = NULL;
}
还有第二种模板,修改起来更方便,但速度稍慢。注意:这个模板的kth和find都是从大到小计数的(即求第k大的key值,以及给定key是第几大),如下:
// Treap node for template 2: children stored in an array so one rotate()
// handles both directions.
struct Node{
    int size;       // number of nodes in this subtree
    int rank;       // random heap priority
    int key;        // stored value
    Node* son[2];   // son[0]: side for keys below `key`, son[1]: side for keys above

    // Order nodes by heap priority.
    bool operator < (const Node& a) const { return rank < a.rank; }

    // Which child x belongs under: 0 (smaller), 1 (larger), or -1 if x == key.
    int cmp(int x) const {
        if (x < key) return 0;
        if (x > key) return 1;
        return -1;
    }

    // Recompute size from the children.
    void update() {
        size = 1 + (son[0] != NULL ? son[0]->size : 0)
                 + (son[1] != NULL ? son[1]->size : 0);
    }
};
// Rotate the subtree rooted at o.
// d = 0: left rotation (son[1] rises); d = 1: right rotation (son[0] rises).
// The demoted node is refreshed first because it is now the child.
void rotate(Node*& o, int d) {
    const int other = d ^ 1;
    Node* oldTop = o;
    Node* newTop = oldTop->son[other];
    oldTop->son[other] = newTop->son[d];
    newTop->son[d] = oldTop;
    oldTop->update();
    newTop->update();
    o = newTop;
}
// Insert key x into the treap rooted at o (max-heap on rank).
// Keys equal to a node's key are sent to son[0]: cmp() reports them as -1,
// which is not a valid child index.
void insert(Node*& o, int x) {
    if (o == NULL) {
        o = new Node();
        o->son[0] = o->son[1] = NULL;
        o->rank = rand();
        o->key = x;
        o->size = 1;
        return;
    }
    int d = o->cmp(x);
    if (d == -1) d = 0;          // duplicate key: o->son[-1] would be out of bounds
    insert(o->son[d], x);
    o->update();
    // Compare the nodes' priorities (Node::operator<), not their addresses:
    // `o < o->son[d]` on the pointers made the heap order arbitrary.
    if (*o < *o->son[d])
        rotate(o, d ^ 1);        // lift son[d] above o
}
// Return the k-th LARGEST key (1-based) in the treap rooted at o:
// son[1] (the larger side) is counted first. Returns -1 if k is out of range.
int kth(Node* o, int k) {
    for (;;) {
        if (o == NULL || k <= 0 || k > o->size)
            return -1;
        int larger = (o->son[1] != NULL) ? o->son[1]->size : 0;
        if (k == larger + 1) return o->key;
        if (k <= larger) {
            o = o->son[1];
        } else {
            k -= larger + 1;      // skip the larger side and this node
            o = o->son[0];
        }
    }
}
// Return the 1-based DESCENDING rank of key k (how many keys are >= it along the
// search path, itself included), or -1 if k is not present.
int find(Node* o, int k) {
    int larger = 0;               // nodes already known to rank above k
    while (o != NULL) {
        int rightSize = (o->son[1] != NULL) ? o->son[1]->size : 0;
        int d = o->cmp(k);
        if (d == -1) return larger + rightSize + 1;
        if (d == 0) larger += rightSize + 1;   // o and its larger side precede k
        o = o->son[d];
    }
    return -1;
}
// Free every node of the treap rooted at a (post-order) and reset a to NULL.
// An empty tree (a == NULL) is accepted; the original dereferenced it and crashed.
void del_tree(Node*& a) {
    if (a == NULL) return;
    del_tree(a->son[0]);
    del_tree(a->son[1]);
    delete a;
    a = NULL;
}
题目链接:[HDU3726] Graph and Queries
知识点:Treap树+并查集+离线逆序操作+启发式合并
调试了很久,最后发现是Treap的remove操作写错了,于是参考别人的写法重写了remove。
#include<cstdio>
#include<cstdlib>
#include<ctime>
using namespace std;
// Capacity limits (with +5 slack): vertices, edges, and commands per test case.
const int maxn = 20000 + 5;
const int maxm = 60000 + 5;
const int maxc = 600000 + 5;

// Undirected edge between vertices u and v.
struct edge {
    int u, v;
}e[maxm];

// One recorded command. `ch` receives the command token via scanf("%s"), so it
// must be a char array. The previous `int ch[5]` mismatched the format (UB) and
// only compared correctly by accident on little-endian machines. 20 chars keeps
// the same byte capacity as before. x is the first argument; k is the second
// (for 'C', main() replaces it with the value to restore during the reverse pass).
struct CMD {
    char ch[20];
    int x, k;
}cmd[maxc];

int fa[maxn];     // union-find parent of each vertex
int val[maxn];    // current value of each vertex
int vis[maxm];    // 1 if the edge is deleted by some 'D' command
int n, m, kase;   // vertex count, edge count, test-case counter
// Treap node for HDU3726: one node per vertex, ordered by (key, id).
struct Node{
    Node* son[2];   // son[0]: smaller (key, id) pairs, son[1]: larger
    int rank;       // random heap priority
    int key;        // the vertex's value
    int size;       // number of nodes in this subtree
    int id;         // vertex number; breaks ties between equal keys
    // Recompute size from the children.
    void update(){
        size = 1 + (son[0] != NULL ? son[0]->size : 0)
                 + (son[1] != NULL ? son[1]->size : 0);
    }
};
// root[r]: treap holding the values of the component whose union-find
// representative is r. It is set to NULL once merge() moves its nodes into
// another component's treap.
Node* root[maxn];
// Union-find lookup with full path compression: return x's representative and
// point every vertex on the path directly at it.
int findfa(int x){
    int rep = x;
    while (fa[rep] != rep) rep = fa[rep];
    while (fa[x] != rep) {
        int next = fa[x];
        fa[x] = rep;
        x = next;
    }
    return rep;
}
// Rotate the subtree at o: d = 0 lifts son[1] (left rotation), d = 1 lifts
// son[0] (right rotation). The demoted node is refreshed before the new top.
inline void rotate(Node*& o, int d) {
    Node* down = o;
    Node* up = down->son[d ^ 1];
    down->son[d ^ 1] = up->son[d];
    up->son[d] = down;
    down->update();
    up->update();
    o = up;
}
// Insert (value x, vertex id) into the treap rooted at o. Nodes are ordered by
// key, then by id for equal keys; a child with a higher rank is rotated above
// its parent. The subtree size of o is refreshed on the way back up.
void insert(Node*& o, int x, int id) {
    if (o != NULL) {
        bool goLeft = x < o->key || (x == o->key && id < o->id);
        int d = goLeft ? 0 : 1;
        insert(o->son[d], x, id);
        if (o->rank < o->son[d]->rank) rotate(o, d ^ 1);
    } else {
        o = new Node();
        o->son[0] = NULL;
        o->son[1] = NULL;
        o->key = x;
        o->rank = rand();
        o->id = id;
    }
    o->update();
}
// Remove the node (value x, vertex id) from the treap rooted at o, if present.
// A matching node with two children is rotated down (higher-rank child rises)
// until it has at most one child, then spliced out. Sizes are refreshed on unwind.
void remove(Node*& o, int x, int id){
    if (o == NULL) return;
    if (o->key != x || o->id != id) {
        bool goLeft = x < o->key || (x == o->key && id < o->id);
        remove(o->son[goLeft ? 0 : 1], x, id);
    } else if (o->son[0] == NULL || o->son[1] == NULL) {
        Node* dead = o;
        o = (o->son[0] == NULL) ? o->son[1] : o->son[0];
        delete dead;
    } else {
        // d2 == 1 lifts son[0] (right rotation); the target then sits at son[1].
        int d2 = (o->son[0]->rank > o->son[1]->rank) ? 1 : 0;
        rotate(o, d2);
        remove(o->son[d2], x, id);
    }
    if (o != NULL) o->update();
}
// Return the k-th LARGEST key (1-based) in the treap rooted at o, or 0 when
// k is out of range (son[1], the larger side, is counted first).
int kth(Node* o, int k) {
    for (;;) {
        if (o == NULL || k <= 0 || k > o->size) return 0;
        int larger = (o->son[1] != NULL) ? o->son[1]->size : 0;
        if (k == larger + 1) return o->key;
        if (k <= larger) {
            o = o->son[1];
        } else {
            k -= larger + 1;
            o = o->son[0];
        }
    }
}
// Move every node of treap b into treap a (post-order), re-parenting each
// moved vertex to representative `id` in the union-find and freeing b's nodes.
// The caller's b ends up NULL. b must not be NULL on entry.
void merge(int id, Node*& a, Node*& b) {
    for (int side = 0; side < 2; ++side)
        if (b->son[side] != NULL) merge(id, a, b->son[side]);
    fa[b->id] = id;
    insert(a, b->key, b->id);
    delete b;
    b = NULL;
}
int main(void) {
while (~scanf("%d %d", &n, &m) && n && m) {
for (int i = 1; i <= n; ++i) scanf("%d", &val[i]);
for (int i = 1; i <= m; ++i) {
scanf("%d %d", &e[i].u, &e[i].v);
vis[i] = 0;
}
char s[10]; int tot = 0;
for (;; tot++) {
scanf("%s", cmd[tot].ch);
if (cmd[tot].ch[0] == 'E') break;
if (cmd[tot].ch[0] == 'D') {
scanf("%d", &cmd[tot].x);
vis[cmd[tot].x] = 1;
}
else
scanf("%d %d", &cmd[tot].x, &cmd[tot].k);
if (cmd[tot].ch[0] == 'C') {
int x = cmd[tot].x, v = cmd[tot].k;
cmd[tot].k = val[x];
val[x] = v;
}
}
for (int i = 1; i <= n; ++i) {
root[i] = NULL;
insert(root[i], val[i], i);
fa[i] = i;
}
for (int i = 1; i <= m; ++i) {
if (vis[i])continue;
int tmpa = findfa(e[i].u), tmpb = findfa(e[i].v);
if (tmpa == tmpb) continue;
if (root[tmpa]->size > root[tmpb]->size) merge(tmpa, root[tmpa], root[tmpb]);
else merge(tmpb, root[tmpb], root[tmpa]);
}
double ans = 0, cnt = 0;
for (int i = tot - 1; i >= 0; --i) {
if (cmd[i].ch[0] == 'D') {
int fa1 = findfa(e[cmd[i].x].u), fa2 = findfa(e[cmd[i].x].v);
if (fa1 == fa2) continue;
if (root[fa1]->size > root[fa2]->size)
merge(fa1, root[fa1], root[fa2]);
else
merge(fa2, root[fa2], root[fa1]);
}
else if (cmd[i].ch[0] == 'Q') {
cnt++;
ans += kth(root[findfa(cmd[i].x)], cmd[i].k);
}
else {
int rt = findfa(cmd[i].x);
remove(root[rt], val[cmd[i].x], cmd[i].x);
insert(root[rt], cmd[i].k, cmd[i].x);
val[cmd[i].x] = cmd[i].k;
}
}
if (cnt == 0)cnt++;
printf("Case %d: %.6lf\n", ++kase, ans / cnt);
}
return 0;
}
Splay树的解法见这篇题解(第三题):https://blog.csdn.net/TK_wang_/article/details/108751229