实现简单的自平衡二叉搜索树(AVL)
在实现二叉搜索树(BST)的基础上增加高度平衡旋转
结构体Node:左结点left、右结点right、数据data、高度height
类BST:根结点root
主要功能有:自平衡旋转、添加、删除、查询、遍历(前序遍历、中序遍历、后序遍历、层序遍历、深度优先搜索、广度优先搜索)、高度计算
时间复杂度:
1、自平衡旋转
比较左右子树根结点高度,时间复杂度为O(1)
2、添加
时间复杂度为O(logn),最多两次单旋转(同一个地方)
3、删除
时间复杂度为O(logn),多次旋转,一直到根都有可能
简单描述一下二叉搜索树的删除过程,可以分为以下四种情况
a) 要删除的节点没有左右孩子
b) 要删除的节点只有左孩子(用左子结点代替)
c) 要删除的节点只有右孩子(用右子结点代替)
d) 要删除的节点有左右孩子(用后继结点代替)
我这里不直接用后继结点替代被删除结点,而是根据左右两颗子树高度,选择前继结点或者后继结点,可无需被删除结点以上的旋转操作
因此只有两种情况:
a) 要删除的节点没有左右孩子
b) 要删除的节点有孩子(根据左右两颗子树高度,选择前继结点或者后继结点)
4、查询
时间复杂度为O(logn)
5、遍历
遍历全结点,时间复杂度为O(n)
6、高度计算
因为AVL树维护了结点高度,只需返回根结点高度,时间复杂度为O(1)
总结:
1、二叉树算法的核心是递归,实现二叉树的任何功能都可通过递归实现
2、前序遍历是dfs的一种
3、BFS就是层序遍历,可以用队列的特点来实现
4、中序遍历是二叉搜索树结点从小到大排序
5、高度计算可以用前序遍历来实现
6、递归处理每个分支的一半,跟二分法类似,属于减治法或者分治法,用分治法的主定理来算T(n)=aT(n/b)+f(n),f(n)∈O(nd),这里a=1,b=2,d=0,a=bd,因此T(n)∈O(ndlogn)=O(logn)
7、AVL树通过维护结点高度来计算平衡因子,不需要每层递归都计算左右子树高度(O(n)),不然会导致添加和删除结点的时间复杂度都为O(nlogn)
#include<iostream>
#include<algorithm>//借用max函数
#include<vector>
using namespace std;
template<typename T>
struct Node {
Node<T>* left;
Node<T>* right;
T data;
int height;
Node(T val) :data(val), height(1),left(nullptr), right(nullptr) {}
};
template<typename T>
class AVL {
private:
Node<T>* root;
//自平衡旋转:右单转、左单转、左右双转、右左双转
Node<T>* insert_rotate(Node<T>* root, T val, int lr) {
if (lr == 0) {
if (rheight(root->left) - rheight(root->right) == 2) {
if (root->left->data >val)
root = rightRotate(root);
else root = lrRotate(root);
}
}
else {
if (rheight(root->right) - rheight(root->left) == 2) {
if (root->right->data < val)
root = leftRotate(root);
else root = rlRotate(root);
}
}
return root;
}
Node<T>* delete_rotate(Node<T>* root, T val, int lr) {
if (lr == 1) {
if (rheight(root->left) - rheight(root->right) == 2) {
if (root->left->data > val)
root = rightRotate(root);
else root = lrRotate(root);
}
}
else {
if (rheight(root->right) - rheight(root->left) == 2) {
if (root->right->data < val)
root = leftRotate(root);
else root = rlRotate(root);
}
}
return root;
}
Node<T>* rightRotate(Node<T>* root) {
Node<T>* rleft = root->left;
root->left = rleft->right;
rleft->right = root;
root->height = max(rheight(root->left), rheight(root->right)) + 1;
rleft->height = max(rheight(rleft->left), rheight(rleft->right)) + 1;
return rleft;
}
Node<T>* leftRotate(Node<T>* root) {
Node<T>* rright = root->right;
root->right = rright->left;
rright->left = root;
root->height = max(rheight(root->left), rheight(root->right)) + 1;
rright->height = max(rheight(rright->left), rheight(rright->right)) + 1;
return rright;
}
Node<T>* lrRotate(Node<T>* root) {
root->left = leftRotate(root->left);
return rightRotate(root);
}
Node<T>* rlRotate(Node<T>* root) {
root->right = rightRotate(root->right);
return leftRotate(root);
}
//插入辅助函数
Node<T>* insert(Node<T>* root, T val) {
if (root) {
if (root->data > val) {
root->left = insert(root->left, val);
root->height = max(rheight(root->left), rheight(root->right)) + 1;
root = insert_rotate(root, val, 0);
return root;
}
if (root->data < val) {
root->right = insert(root->right, val);
root->height = max(rheight(root->left), rheight(root->right)) + 1;
root = insert_rotate(root, val, 1);
return root;
}
return nullptr;
}
return new Node<T>(val);
}
//删除辅助函数
Node<T>* del(Node<T>* root, T val) {
if (root) {
if (root->data > val) {
root->left = del(root->left, val);
if(root->left)
root->left->height= max(rheight(root->left->left), rheight(root->left->right)) + 1;
root->height = max(rheight(root->left), rheight(root->right)) + 1;
if (root->right) {
if (root->right->right)
root = delete_rotate(root, root->right->right->data, 0);
}
}
else if (root->data < val) {
root->right = del(root->right, val);
if (root->right)
root->right->height = max(rheight(root->right->left), rheight(root->right->right)) + 1;
root->height = max(rheight(root->left), rheight(root->right)) + 1;
if (root->left) {
if (root->left->left)
root = delete_rotate(root, root->left->left->data, 1);
}
}
else {
Node<T>* r = nullptr;
if (root->left || root->right) {
r = root;
if (rheight(root->left) > rheight(root->right)) {
Node<T>* left = r->left;
int i;
for (i = 0; left->right != nullptr; i++) {
left = left->right;
if (i)
r = r->right;
else r = r->left;
}
if (i) {
r->right = left->left;
if (root)
del(root, left->data);
left->right = root->right;
left->left = root->left;
}
else left->right = root->right;
r = left;
}
else {
Node<T>* right = r->right;
int i;
for (i = 0; right->left != nullptr;i++) {
right = right->left;
if (i)
r = r->left;
else r = r->right;
}
if (i) {
r->left = right->right;
if (root)
del(root, right->data);
right->right = root->right;
right->left = root->left;
}
else right->left = root->left;
r = right;
}
}
root->left = nullptr;
root->right = nullptr;
root = r;
}
return root;
}
return nullptr;
}
//查找辅助函数
bool find(Node<T>* root, T val) {
if (root) {
if (root->data == val)
return true;
if (root->data > val)
return find(root->left, val);
else return find(root->right, val);
}
return false;
}
//遍历辅助函数:前序遍历、中序遍历、后序遍历
void preOrder(Node<T>* root) {
if (root) {
cout << root->data;
preOrder(root->left);
preOrder(root->right);
}
}
void inOrder(Node<T>* root) {
if (root) {
inOrder(root->left);
cout << root->data;
inOrder(root->right);
}
}
void postOrder(Node<T>* root) {
if (root) {
postOrder(root->left);
postOrder(root->right);
cout << root->data;
}
}
int rheight(Node<T>* node) {
if (node)
return node->height;
return 0;
}
public:
AVL(T val) { root = new Node<T>(val); }
~AVL() {}
//添加
bool insert(T val) {
Node<T>* p = insert(root, val);
if (p) {
root = p;
return true;
}
return false;
}
//删除
bool del(T val) {
Node<T>* p = del(root, val);
root->height = max(rheight(root->left), rheight(root->right)) + 1;
if (p) {
root = p;
return true;
}
return false;
}
//查询
bool find(T val) {
return find(root, val);
}
//遍历:前序遍历(dfs)、中序遍历、后序遍历、层序遍历(bfs)
//前序遍历
void preOrder() {
preOrder(root);
cout << endl;
}
//中序遍历
void inOrder() {
inOrder(root);
cout << endl;
}
//后序遍历
void postOrder() {
postOrder(root);
cout << endl;
}
//层序遍历
void levelOrder() {
if (root) {
vector<Node<T>*> nodes;
nodes.push_back(root);
int n = 1;
for (int i = 0; i < n; i++) {
cout << nodes[i]->data;
if (nodes[i]->left) {
nodes.push_back(nodes[i]->left);
n++;
}
if (nodes[i]->right) {
nodes.push_back(nodes[i]->right);
n++;
}
}
}
cout << endl;
}
//深度优先搜索
void dfs() { preOrder(); }
//广度优先搜索
void bfs() { levelOrder(); }
int height() { return root->height; }
};
int main() {
AVL<int> avl(4);
//添加测试
cout << "----------添加测试----------" << endl;
cout << "前序遍历:";
avl.preOrder();
avl.insert(2);
cout << "前序遍历:";
avl.preOrder();
avl.insert(6);
cout << "前序遍历:";
avl.preOrder();
avl.insert(1);
cout << "前序遍历:";
avl.preOrder();
avl.insert(3);
cout << "前序遍历:";
avl.preOrder();
avl.insert(7);
cout << "前序遍历:";
avl.preOrder();
avl.insert(5);
cout << "前序遍历:";
avl.preOrder();
avl.insert(8);
cout << "前序遍历:";
avl.preOrder();
/*
4
↙ ↘
2 6
↙ ↘ ↙ ↘
1 3 5 7
↘
8
*/
//高度平衡测试
cout << "----------高度平衡测试----------" << endl;
avl.insert(9);
cout << "前序遍历:";
avl.preOrder();
/*
4
↙ ↘
2 6
↙ ↘ ↙ ↘
1 3 5 8
↙ ↘
7 9
*/
//遍历测试
cout << "----------遍历测试----------" << endl;
cout << "前序遍历:";
avl.preOrder();
cout << "中序遍历:";
avl.inOrder();
cout << "后序遍历:";
avl.postOrder();
cout << "层序遍历:";
avl.levelOrder();
cout << "深度优先搜索:";
avl.dfs();
cout << "广度优先搜索:";
avl.bfs();
//高度计算测试
cout << "----------高度计算测试----------" << endl;
cout << "height:" << avl.height() << endl;
//删除测试
cout << "----------删除测试----------" << endl;
cout << "删除4:" << avl.del(4) << endl;
cout << "前序遍历:";
avl.preOrder();
//查询测试
cout << "----------查询测试----------" << endl;
cout << "查找7:" << avl.find(7) << endl;
}
附上测试用例
如有不对,请大佬指出orz