在完成了左旋和右旋操作之后,我们要实现调用两种单旋操作的逻辑,也就是调整操作。
在插入之后,一共有 44 种触发旋转的情况,分别为 LL 型、LR 型、RR 型和 RL 型。
通过旋转的名称,可以很直观的想到其对应的不平衡的情况——比如 LR 型,就意味着左子树(L)的右子树(R)的元素个数过大。
还记得我们之前讲到过的,SBTree 的平衡条件么?一共有两个:
a: size[right[t]] ≥ max(size[left[left[t]]], size[right[left[t]]])
a:size[right[t]]≥max(size[left[left[t]]],size[right[left[t]]])
b: size[left[t]] ≥ max(size[left[right[t]]], size[right[right[t]]])
b:size[left[t]]≥max(size[left[right[t]]],size[right[right[t]]])
对于 LL 和 LR 型,违反了平衡条件 aa;对于 RR 和 RL 型,则违反了平衡条件 bb。
对应的,LL 型和 LR 型都一定能保证平衡条件 bb 的成立;RR 型和 RL 型也都能一定保证平衡条件 aa 的成立。
我们只需要检查其中一半的情况即可,来避免无谓的判断。我们可以将算法伪代码简化如下:
如果在处理左子树更高的情况:
LL 型:右旋 tt。
LR 型:左旋 tt 的左子树,再右旋 tt。
如果在处理右子树更高的情况:
RR 型:左旋 tt。
RL 型:右旋 tt 的右子树,再左旋 tt。
递归调整左子树其中左子树的左子树更高的情况
递归调整右子树其中右子树的右子树更高的情况
递归调整当前子树其中左子树更高的情况
递归调整当前子树其中右子树更高的情况
为什么可以不考虑右子树其中右子树的左子树更高的情况呢?因为这种情况在第 66 步已经被处理了。左子树其中左子树的右子树更高的情况也类似。因此我们可以通过和之前介绍的相比更简洁、高效的调整算法实现对 SBTree 的维护。
#include <iostream>
using namespace std;
class SBTNode {
public:
int data, size, value;
SBTNode * lchild, * rchild, * father;
SBTNode(int init_data, int init_size = 0, SBTNode * init_father = NULL);
~SBTNode();
void insert(int value);
SBTNode * search(int value);
SBTNode * predecessor();
SBTNode * successor();
void remove_node(SBTNode * delete_node);
bool remove(int value);
int select(int k);
};
class BinaryTree {
private:
SBTNode * root;
public:
BinaryTree();
~BinaryTree();
void insert(int value);
bool find(int value);
bool remove(int value);
int select(int k);
};
SBTNode ZERO(0);
SBTNode * ZPTR = &ZERO;
SBTNode::SBTNode(int init_data, int init_size, SBTNode * init_father) {
data = init_data;
size = init_size;
lchild = ZPTR;
rchild = ZPTR;
father = init_father;
}
SBTNode::~SBTNode() {
if (lchild != ZPTR) {
delete lchild;
}
if (rchild != ZPTR) {
delete rchild;
}
}
SBTNode * left_rotate(SBTNode * node) {
SBTNode * temp = node->rchild;
node->rchild = temp->lchild;
temp->lchild->father = node;
temp->lchild = node;
temp->father = node->father;
node->father = temp;
temp->size = node->size;
node->size = node->lchild->size + node->rchild->size + 1;
return temp;
}
SBTNode * right_rotate(SBTNode * node) {
SBTNode * temp = node->lchild;
node->lchild = temp->rchild;
temp->rchild->father = node;
temp->rchild = node;
temp->father = node->father;
node->father = temp;
temp->size = node->size;
node->size = node->lchild->size + node->rchild->size + 1;
return temp;
}
SBTNode * maintain(SBTNode * node, bool flag) {
if (flag == false) {
if (node->lchild->lchild->size > node->rchild->size) {
node = right_rotate(node);
} else if (node->lchild->rchild->size > node->rchild->size) {
node->lchild = left_rotate(node->lchild);
node = right_rotate(node);
} else {
return node;
}
} else {
if (node->rchild->rchild->size > node->lchild->size) {
node = left_rotate(node);
} else if (node->rchild->lchild->size > node->lchild->size) {
node->rchild = right_rotate(node->rchild);
node = left_rotate(node);
} else {
return node;
}
}
node->lchild = maintain(node->lchild, false);
node->rchild = maintain(node->rchild, true);
node = maintain(node, false);
node = maintain(node, true);
return node;
}
SBTNode * insert(SBTNode * node, int value) {
if (value == node->data) {
return node;
} else {
node->size++;
if (value > node->data) {
if (node->rchild == ZPTR) {
node->rchild = new SBTNode(value, 1, node);
} else {
node->rchild = insert(node->rchild, value);
}
} else {
if (node->lchild == ZPTR) {
node->lchild = new SBTNode(value, 1, node);
} else {
node->lchild = insert(node->lchild, value);
}
}
}
return maintain(node, value > node->data);
}
SBTNode * SBTNode::search(int value) {
if (data == value) {
return this;
} else if (value > data) {
if (rchild == ZPTR) {
return ZPTR;
} else {
return rchild->search(value);
}
} else {
if (lchild == ZPTR) {
return ZPTR;
} else {
return lchild->search(value);
}
}
}
SBTNode * SBTNode::predecessor() {
SBTNode * temp = lchild;
while (temp != ZPTR && temp->rchild != ZPTR) {
temp = temp->rchild;
}
return temp;
}
SBTNode * SBTNode::successor() {
SBTNode * temp = rchild;
while (temp != ZPTR && temp->lchild != ZPTR) {
temp = temp->lchild;
}
return temp;
}
void SBTNode::remove_node(SBTNode * delete_node) {
SBTNode * temp = ZPTR;
if (delete_node->lchild != ZPTR) {
temp = delete_node->lchild;
temp->father = delete_node->father;
delete_node->lchild = ZPTR;
}
if (delete_node->rchild != ZPTR) {
temp = delete_node->rchild;
temp->father = delete_node->father;
delete_node->rchild = ZPTR;
}
if (delete_node->father->lchild == delete_node) {
delete_node->father->lchild = temp;
} else {
delete_node->father->rchild = temp;
}
temp = delete_node;
while (temp != NULL) {
temp->size--;
temp = temp->father;
}
delete delete_node;
}
bool SBTNode::remove(int value) {
SBTNode * delete_node, * current_node;
current_node = search(value);
if (current_node == ZPTR) {
return false;
}
size--;
if (current_node->lchild != ZPTR) {
delete_node = current_node->predecessor();
} else if (current_node->rchild != ZPTR) {
delete_node = current_node->successor();
} else {
delete_node = current_node;
}
current_node->data = delete_node->data;
remove_node(delete_node);
return true;
}
int SBTNode::select(int k) {
int rank=lchild->size+1;
if(rank==k){
return data;
}
else if(k<rank){
return lchild->select(k);
}else{
return rchild->select(k-rank);
}
}
BinaryTree::BinaryTree() {
root = NULL;
}
BinaryTree::~BinaryTree() {
if (root != NULL) {
delete root;
}
}
void BinaryTree::insert(int value) {
if (root == NULL) {
root = new SBTNode(value, 1);
} else {
root = ::insert(root, value);
}
}
bool BinaryTree::find(int value) {
if (root->search(value) == NULL) {
return false;
} else {
return true;
}
}
bool BinaryTree::remove(int value) {
return root->remove(value);
}
int BinaryTree::select(int k) {
return root->select(k);
}
int main() {
BinaryTree binarytree;
int arr[10] = { 8, 9, 10, 3, 2, 1, 6, 4, 7, 5 };
for (int i = 0; i < 10; i++) {
binarytree.insert(arr[i]);
}
int k;
cin >> k;
cout << binarytree.select(k) << endl;
return 0;
}