实现简单的红黑树(RBT)
在实现二叉搜索树(AVL)的基础上增加红黑平衡旋转
性质:
1、根节点是黑色
2、插入节点是红的
3、红色节点的子节点是黑的
4、从任一节点到其每个叶子的所有路径都包含相同数目的黑色节点
结构体Node:左结点left、右结点right、数据data、颜色col
类RBT:根结点root
主要功能有:规则判断(红黑平衡旋转)、添加、删除、查询、遍历(前序遍历、中序遍历、后序遍历、层序遍历、深度优先搜索、广度优先搜索)、高度计算
时间复杂度:
1、规则判断(红黑平衡旋转)
比较左右子树红黑个数,时间复杂度为O(1)
2、添加
时间复杂度为O(logn),最多两次单旋转(同一个地方)
3、删除
时间复杂度为O(logn),最多三次单旋转(同一个地方)
简单描述一下二叉搜索树的删除过程,可以分为以下四种情况
a) 要删除的节点没有左右孩子
b) 要删除的节点只有左孩子(用左子结点代替)
c) 要删除的节点只有右孩子(用右子结点代替)
d) 要删除的节点有左右孩子(用后继结点代替)
我这里不直接用后继结点替代被删除结点,而是根据左右两颗子树高度,选择前继结点或者后继结点
a) 要删除的节点没有左右孩子
b) 要删除的节点有孩子(根据左右两颗子树高度,选择前继结点或者后继结点)
4、查询
时间复杂度为O(logn)
5、遍历
遍历全结点,时间复杂度为O(n)
6、高度计算
因为AVL树维护了结点高度,只需返回根结点高度,时间复杂度为O(1)
总结:
1、二叉树算法的核心是递归,实现二叉树的任何功能都可通过递归实现
2、前序遍历是dfs的一种
3、BFS就是层序遍历,可以用队列的特点来实现
4、中序遍历是二叉搜索树结点从小到大排序
5、高度计算可以用前序遍历来实现
6、递归处理每个分支的一半,跟二分法类似,属于减治法或者分治法,用分治法的主定理来算T(n)=aT(n/b)+f(n),f(n)∈O(nd),这里a=1,b=2,d=0,a=bd,因此T(n)∈O(ndlogn)=O(logn)
7、AVL树是高度平衡,插入删除性能低,搜索性能高,红黑树搜索性能相对较低,插入删除性能高。若搜索次数远远大于插入删除,应选择AVL树,若搜索次数和插入删除次数差不多,则选择红黑树
#include<iostream>
#include<algorithm>//借用max函数
#include<vector>
using namespace std;
enum color { BLACK, RED };
template<typename T>
struct Node {
Node<T>* left;
Node<T>* right;
T data;
color col;
Node(T val) :data(val), col(RED), left(nullptr), right(nullptr) {}
};
template<typename T>
class RBT {
private:
Node<T>* root;
Node<T>* rule(Node<T>* root, color befcol, T befval, T val, int lr) {
color aftcol;
T aftval;
if (lr == 0) {
if (root->left) {
aftcol = root->left->col;
aftval = root->left->data;
}
else {
aftcol = BLACK;
aftval = -1;
}
if (aftcol != befcol && befval == aftval) {
if (aftcol == BLACK) {
root->col = RED;
if (root->right) {
if (root->right->col == RED)
root->right->col = BLACK;
else root = rotate(root, val, 0);
}
else root = rotate(root, val, 1);
}
}
if (root->left) {
if (root->col == RED && root->left->col == RED) {
root->col = BLACK;
}
}
}
else {
if (root->right) {
aftcol = root->right->col;
aftval = root->right->data;
}
else {
aftcol = BLACK;
aftval = -1;
}
if (befcol != aftcol && befval == aftval) {
if (aftcol == BLACK) {
root->col = RED;
if (root->left) {
if (root->left->col == RED)
root->left->col = BLACK;
else root = rotate(root, val, 1);
}
else root = rotate(root, val, 1);
}
}
if (root->right) {
if (root->col == RED && root->right->col == RED) {
root->col = BLACK;
}
}
}
return root;
}
//自平衡旋转:右单转、左单转、左右双转、右左双转
Node<T>* rotate(Node<T>* root, T val, int lr) {
if (lr == 0) {
if (val < root->left->data)
root = rightRotate(root);
else root = lrRotate(root);
}
else {
if (val > root->right->data)
root = leftRotate(root);
else root = rlRotate(root);
}
return root;
}
Node<T>* rightRotate(Node<T>* root) {
if (root->right)
root->right->col = BLACK;
Node<T>* rleft = root->left;
root->left = rleft->right;
rleft->right = root;
return rleft;
}
Node<T>* leftRotate(Node<T>* root) {
Node<T>* rright = root->right;
root->right = rright->left;
rright->left = root;
return rright;
}
Node<T>* lrRotate(Node<T>* root) {
root->left = leftRotate(root->left);
root->left->col = BLACK;
root->left->left->col = RED;
return rightRotate(root);
}
Node<T>* rlRotate(Node<T>* root) {
root->right = rightRotate(root->right);
root->right->col = BLACK;
root->right->right->col = RED;
return leftRotate(root);
}
//插入辅助函数
Node<T>* insert(Node<T>* root, T val) {
if (root) {
color befcol;
T befval;
if (root->data > val) {
if (root->left) {
befcol = root->left->col;
befval = root->left->data;
}
else {
befcol = BLACK;
befval = -1;
}
root->left = insert(root->left, val);
return rule(root, befcol, befval, val, 0);;
}
if (root->data < val) {
if (root->right) {
befcol = root->right->col;
befval = root->right->data;
}
else {
befcol = BLACK;
befval = -1;
}
root->right = insert(root->right, val);
return rule(root, befcol, befval, val, 1);
}
return nullptr;
}
return new Node<T>(val);
}
//删除辅助函数
Node<T>* del(Node<T>* root, T val) { return nullptr; }
//查找辅助函数
bool find(Node<T>* root, T val) {
if (root) {
if (root->data == val)
return true;
if (root->data > val)
return find(root->left, val);
else return find(root->right, val);
}
return false;
}
//遍历辅助函数
void preOrder(Node<T>* root) {
if (root) {
cout << root->data;
preOrder(root->left);
preOrder(root->right);
}
}
void inOrder(Node<T>* root) {
if (root) {
inOrder(root->left);
cout << root->data;
inOrder(root->right);
}
}
void postOrder(Node<T>* root) {
if (root) {
postOrder(root->left);
postOrder(root->right);
cout << root->data;
}
}
//高度计算辅助函数
int height(Node<T>* root) {
if (root) {
return max(height(root->left), height(root->right)) + 1;
}
return 0;
}
public:
RBT(T val) { root = new Node<T>(val); root->col = BLACK; }
~RBT() {}
//添加
bool insert(T val) {
Node<T>* p = insert(root, val);
if (p) {
root = p;
root->col = BLACK;
return true;
}
return false;
}
//删除
bool del(T val) {
return false;
}
//查询
bool find(T val) {
return find(root, val);
}
//遍历:前序遍历(dfs)、中序遍历、后序遍历、层序遍历(bfs)
//前序遍历
void preOrder() {
preOrder(root);
cout << endl;
}
//中序遍历
void inOrder() {
inOrder(root);
cout << endl;
}
//后序遍历
void postOrder() {
postOrder(root);
cout << endl;
}
//层序遍历
void levelOrder() {
if (root) {
vector<Node<T>*> nodes;
nodes.push_back(root);
int n = 1;
for (int i = 0; i < n; i++) {
cout << nodes[i]->data;
if (nodes[i]->left) {
nodes.push_back(nodes[i]->left);
n++;
}
if (nodes[i]->right) {
nodes.push_back(nodes[i]->right);
n++;
}
}
}
cout << endl;
}
//深度优先搜索
void dfs() { preOrder(); }
//广度优先搜索
void bfs() { levelOrder(); }
//高度计算
int height() {
return height(root);
}
};
int main() {
RBT<int> rbt(4);
//添加测试
cout << "----------添加测试----------" << endl;
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(2);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(6);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(1);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(3);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(7);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(5);
cout << "前序遍历:";
rbt.preOrder();
rbt.insert(8);
cout << "前序遍历:";
rbt.preOrder();
/*
4
↙ ↘
2 6
↙ ↘ ↙ ↘
1 3 5 7
↘
8
*/
//红黑平衡测试
cout << "----------红黑平衡测试----------" << endl;
rbt.insert(9);
cout << "前序遍历:";
rbt.preOrder();
/*
4
↙ ↘
2 6
↙ ↘ ↙ ↘
1 3 5 8
↙ ↘
7 9
*/
//遍历测试
cout << "----------遍历测试----------" << endl;
cout << "前序遍历:";
rbt.preOrder();
cout << "中序遍历:";
rbt.inOrder();
cout << "后序遍历:";
rbt.postOrder();
cout << "层序遍历:";
rbt.levelOrder();
cout << "深度优先搜索:";
rbt.dfs();
cout << "广度优先搜索:";
rbt.bfs();
//高度计算测试
cout << "----------高度计算测试----------" << endl;
cout << "height:" << rbt.height() << endl;
//查询测试
cout << "----------查询测试----------" << endl;
cout << "查找7:" << rbt.find(7) << endl;
}
附上测试用例
如有不对,请大佬指出orz