读了《数据结构与算法分析》后发现，与红黑树相比，AVL 平衡树更加简单明了。代码如下（avltree.h）：
#ifndef AVLTREE_H
#define AVLTREE_H
#include <cstdio>
#include <vector>
// A node of AVLTree: one key/value pair plus child/parent links and the
// bookkeeping fields (cached subtree height and depth) used by the tree.
template<typename K, typename V>
class AVLNode{
public:
    // Creates an unlinked node with height 1 and layer 0. key and value are
    // value-initialised (e.g. 0 for arithmetic types) so a default-constructed
    // node never holds indeterminate data.
    AVLNode()
        : left(nullptr), right(nullptr), parent(nullptr),
          key(), value(), height(1), layer(0)
    {
    }

    // Creates an unlinked leaf node holding (k, v). The members are
    // initialised directly, so K and V need not be default-constructible.
    AVLNode(K k, V v)
        : left(nullptr), right(nullptr), parent(nullptr),
          key(k), value(v), height(1), layer(0)
    {
    }

    AVLNode *left;    // left child, or nullptr
    AVLNode *right;   // right child, or nullptr
    AVLNode *parent;  // parent node, or nullptr for the root / an unlinked node
    K key;
    V value;
    int height;       // cached height of the subtree rooted here; a leaf is 1
    int layer;        // depth below the root (root = 0), set by AVLTree::updateLayer()
};
template<typename K, typename V>
class AVLTree
{
public:
AVLTree(int compare(K a, K b),int dump(int height, K &a, V &v))
{
root = nullptr;
this->compare = compare;
this->dump_t = dump;
}
~AVLTree(){
deleteNode(root);
}
void deleteNode(AVLNode<K,V> *node){
if(node == nullptr){
return;
}
deleteNode(node->left);
deleteNode(node->right);
delete node;
}
void insert(K key, V value){
root = insertNode(root, key, value);
}
int max(int a, int b){
if(a > b){
return a;
}else{
return b;
}
}
int height(AVLNode<K,V> *node){
if(node == nullptr){
return 0;
}
return node->height;
}
int updateHeight(AVLNode<K,V> *node)
{
int h = max(height(node->left),height(node->right)) + 1;
return h;
}
AVLNode<K,V> *signalRotateLeftChild(AVLNode<K,V> *k2){
AVLNode<K,V> *k1 = k2->left;
k2->left = k1->right;
if(k2->left){
k2->left->parent = k2;
}
k1->right = k2;
k1->parent = k2->parent;
k2->parent = k1;
k2->height = updateHeight(k2);
k1->height = updateHeight(k1);
return k1;
}
AVLNode<K,V> *doubleRotateLeftChild(AVLNode<K,V> *k2){
k2->left = signalRotateRightChild(k2->left);
if(k2->left){
k2->left->parent = k2;
}
return signalRotateLeftChild(k2);
}
AVLNode<K,V> *signalRotateRightChild(AVLNode<K,V> *k2){
AVLNode<K,V> *k1 = k2->right;
k2->right = k1->left;
if(k2->right){
k2->right->parent = k2;
}
k1->left = k2;
k1->parent = k2->parent;
k2->parent = k1;
k2->height = updateHeight(k2);
k1->height = updateHeight(k1);
return k1;
}
AVLNode<K,V> *doubleRotateRightChild(AVLNode<K,V> *k2){
k2->right = signalRotateLeftChild(k2->right);
if(k2->right){
k2->right->parent = k2;
}
return signalRotateRightChild(k2);
}
AVLNode<K,V> *insertNode(AVLNode<K,V> *node, K key, V value){
if(node == nullptr){
return new AVLNode<K,V>(key,value);
}
int result = compare(key, node->key);
if(result < 0){
node->left = insertNode(node->left, key, value);
if(node->left){
node->left->parent = node;
}
}else if(result > 0){
node->right = insertNode(node->right, key, value);
if(node->right){
node->right->parent = node;
}
}else{
return node;
}
int leftHeight = height(node->left);
int rightHeight = height(node->right);
if(leftHeight - rightHeight == 2){
if(compare(key, node->left->key) < 0){
node = signalRotateLeftChild(node);
}else{
node = doubleRotateLeftChild(node);
}
}else if(rightHeight - leftHeight == 2){
if(compare(key, node->right->key) > 0){
node = signalRotateRightChild(node);
}else{
node = doubleRotateRightChild(node);
}
}
node->height = updateHeight(node);
return node;
}
AVLNode<K,V> *findMin(AVLNode<K,V> *node){
if(node == nullptr) {
return node;
}
if(node->left == nullptr) {
return node;
}
return findMin(node->left);
}
void remove(K key){
root = remove(root,key);
}
AVLNode<K,V> *remove(AVLNode<K,V> *node ,K key){
if(node == nullptr){
return node;
}
int result = compare(key,node->key);
if(result == 0) {
// 右枝还有叶子,找到右枝最小叶子,代替要删除的节点,
// 否则直接删除此节点
if(node->right){
AVLNode<K,V> *rightMin = findMin(node->right);
if(rightMin == nullptr) {
if(node->parent){
if(node->parent->left == node){
node->parent->left = node->right;
}else{
node->parent->right = node->right;
}
}
AVLNode<K,V> *temp;
temp = node->right;
if(temp){
temp->parent = node->parent;
}
delete node;
node = temp;
}else {
node->key = rightMin->key;
node->right = remove(node->right, node->key);
if(node->right){
node->right->parent = node;
}
}
}else{
AVLNode<K,V> *temp;
temp = node->left;
if(temp){
temp->parent = node->parent;
}
delete node;
node = temp;
}
}else if(result < 0) {
node->left = remove(node->left, key);
if(node->left){
node->left->parent = node;
}
}else {
node->right = remove(node->right, key);
if(node->right){
node->right->parent = node;
}
}
if(node == nullptr){
return node;
}
if(height(node->left) - height(node->right) == 2) {
int k1h = height(node->left->left);
int k2h = height(node->left->right);
if(k1h > k2h) {
node = signalRotateLeftChild(node);
}else {
node = doubleRotateLeftChild(node);
}
}else if(height(node->right) - height(node->left) == 2) {
int k1h = height(node->right->left);
int k2h = height(node->right->right);
if(k2h > k1h) {
node = signalRotateRightChild(node);
}else {
node = doubleRotateRightChild(node);
}
}
node->height = max( height(node->right) , height(node->left))+1;
return node;
}
AVLNode<K,V> *findMax(AVLNode<K,V> *node){
if(node == nullptr) {
return node;
}
if(node->right == nullptr) {
return node;
}
return findMax(node->right);
}
//这个删除实现可能更好,更加容易理解
void remove2(K key){
root = remove2(root,key);
}
AVLNode<K,V> *remove2(AVLNode<K,V> *node ,K key){
if(node == nullptr){
return node;
}
int result = compare(key,node->key);
if(result == 0) {
if(node->right == nullptr && node->left == nullptr){
delete node;
node = nullptr;
}else if(height(node->right) > height(node->left)){
AVLNode<K,V> *find = findMin(node->right);
node->key = find->key;
node->right = remove2(node->right, node->key);
if(node->right){
node->right->parent = node;
}
}else{
AVLNode<K,V> *find = findMax(node->left);
node->key = find->key;
node->left = remove2(node->left, node->key);
if(node->left){
node->left->parent = node;
}
}
}else if(result < 0) {
node->left = remove2(node->left, key);
if(node->left){
node->left->parent = node;
}
}else {
node->right = remove2(node->right, key);
if(node->right){
node->right->parent = node;
}
}
if(node == nullptr){
return node;
}
if(height(node->left) - height(node->right) == 2) {
int k1h = height(node->left->left);
int k2h = height(node->left->right);
if(k1h > k2h) {
node = signalRotateLeftChild(node);
}else {
node = doubleRotateLeftChild(node);
}
}else if(height(node->right) - height(node->left) == 2) {
int k1h = height(node->right->left);
int k2h = height(node->right->right);
if(k2h > k1h) {
node = signalRotateRightChild(node);
}else {
node = doubleRotateRightChild(node);
}
}
node->height = max( height(node->right) , height(node->left))+1;
return node;
}
void updateLayer()
{
updateLayerNode(root,0);
}
void updateLayerNode(AVLNode<K,V> *node, int layer){
if(node == nullptr){
return;
}
node->layer = layer;
updateLayerNode(node->left,layer+1);
updateLayerNode(node->right,layer+1);
}
void dumpSpace(int h, int layer)
{
int count = 1;
for(int i = 0;i < h+1 - layer;i++){
count *= 2;
}
count = count-4;
for(int i = 0;i < count;i++){
printf(" ");
}
}
void dumpSpace2(int h, int layer)
{
dumpSpace(h,layer);
printf(" ");
}
void dump3(){
if(root == nullptr){
return;
}
updateLayer();
std::vector<AVLNode<K,V> *>list;
int count = 1;
list.push_back(root);
int layer = 1;
int layerNum = 0;
int layerCount = 1;
while(list.size() > 0 && count > 0){
AVLNode<K,V> *node = list.at(0);
list.erase(list.begin());
if(node){
count--;
if(dump_t){
dumpSpace(root->height, layerNum);
dump_t(node->height,node->key,node->value);
dumpSpace2(root->height, layerNum);
}
if(node->left){
count++;
list.push_back(node->left);
}else{
list.push_back(nullptr);
}
if( node->right){
count++;
list.push_back(node->right);
}else{
list.push_back(nullptr);
}
}else{
list.push_back(nullptr);
list.push_back(nullptr);
dumpSpace(root->height, layerNum);
printf("___,");
dumpSpace2(root->height, layerNum);
}
layerCount--;
if(layerCount == 0){
layerNum++;
layer *= 2;
layerCount = layer;
printf("\n");
}
}
printf("\n");
}
void dump(){
dumpNode(root, 0);
}
void dumpNode(AVLNode<K,V> *node, int layer){
if(!node){
return;
}
int n = node->height;
for(int i = 0;i <= n;i++){
printf(" ");
}
if(dump_t){
dump_t(node->height, node->key, node->value);
}
dumpNode(node->left,layer+1);
dumpNode(node->right,layer+1);
printf("\n");
}
void dump2(){
dumpNode2(root, 0);
}
void dumpNode2(AVLNode<K,V> *node, int layer){
if(!node){
return;
}
dumpNode2(node->left,layer+1);
if(dump_t){
dump_t(node->height, node->key, node->value);
}
dumpNode2(node->right,layer+1);
}
private:
AVLNode<K,V> *root;
int (*compare)(K a, K b);
int (*dump_t)(int height, K &a, V &v);
};
#endif // AVLTREE_H
简单的测试代码（需包含上面的 avltree.h）：
#include "avltree.h"
// Three-way comparison of two ints for AVLTree: returns -1, 0 or 1.
// It compares instead of returning `a - b`, because the subtraction
// overflows (undefined behaviour) for large operands of opposite sign,
// e.g. cmp(INT_MIN, 1).
int cmp(int a, int b){
    return (a > b) - (a < b);
}
// Node printer passed to AVLTree. It writes the key as a zero-padded,
// three-digit field followed by a comma. The height and value arguments
// are not printed. Always returns 0.
int dump(int h, int &key, int &value)
{
    (void)h;
    (void)value;
    const int status = 0;
    printf("%03d,", key);
    return status;
}
// Smoke test. It inserts keys 0..19 in ascending order, printing the tree
// after each insertion, then removes them in descending order, printing the
// tree after each removal.
void testAVL()
{
    // A stack object, so the tree and all of its nodes are freed on return.
    // The previous `new AVLTree` was never deleted and leaked.
    AVLTree<int, int> tree(cmp, dump);
    const int count = 20;
    for(int i = 0;i < count;i++){
        printf("\n-- insert: %d\n",i);
        tree.insert(i,i);
        tree.dump3();
    }
    tree.dump3();
    for(int i = count - 1;i >= 0;i--){
        printf("\n-- remove: %d\n",i);
        tree.remove(i);
        tree.dump3();
    }
    tree.dump3();
}
// Program entry point: runs the AVL tree smoke test. Reaching the end of
// main() without a return statement is equivalent to `return 0;`.
int main()
{
    testAVL();
}