平衡二叉排序树(AVL树)源码
AVLTree.h源码
#pragma once
#include "stdafx.h"
#include <vector>
using namespace std;
// One node of the AVL tree: a float key, the height of the subtree rooted at
// this node, and raw (non-owning) links to the left and right children.
class AVLNode
{
public:
    float nodeValue;   // ordering key
    int treeHeight;    // height of the subtree rooted here; starts at 0
    AVLNode* lChild;   // left child, or NULL
    AVLNode* rChild;   // right child, or NULL

    // Builds a detached node holding v: no children, height 0.
    AVLNode(float v) : nodeValue(v), treeHeight(0), lChild(NULL), rChild(NULL)
    {
    }

    // Height of the left subtree; an absent child counts as height 0.
    int getLeftHeight()
    {
        if (lChild == NULL)
        {
            return 0;
        }
        return lChild->treeHeight;
    }

    // Height of the right subtree; an absent child counts as height 0.
    int getRightHeight()
    {
        if (rChild == NULL)
        {
            return 0;
        }
        return rChild->treeHeight;
    }

    // Recomputes treeHeight from the children's stored heights:
    // one more than the taller child.
    void refreshHeight()
    {
        int tallest = getLeftHeight();
        int other = getRightHeight();
        if (other > tallest)
        {
            tallest = other;
        }
        treeHeight = tallest + 1;
    }
};
// Self-balancing (AVL) binary search tree ordered by AVLNode::nodeValue.
// Nodes are allocated by the caller; the tree only links them together and
// does not delete them.
class AVLTree
{
public:
    AVLTree();
    // Inserts a caller-owned node.
    void add(AVLNode* node);
    // Unlinks the given node (matched by pointer) from the tree.
    void remove(AVLNode* node);
    // All nodes in in-order (non-decreasing nodeValue) sequence.
    vector<AVLNode*> getSortedList();

private:
    AVLNode* root;

    // Single rotations; `tree` is rewritten to the new subtree root.
    void lRotation(AVLNode* &tree);
    void rRotation(AVLNode* &tree);
    // Double rotations for the left-right / right-left cases.
    void lrRotation(AVLNode* &tree);
    void rlRotation(AVLNode* &tree);
    // Restores the AVL balance condition at `tree`.
    void rebalance(AVLNode* &tree);
    // Recursive workers behind the public add/remove.
    void add(AVLNode* &tree, AVLNode* node);
    void remove(AVLNode* &tree, AVLNode* node);
};
AVLTree.cpp源码
#include "stdafx.h"
#include "AVLTree.h"
// Creates an empty tree.
AVLTree::AVLTree() : root(NULL)
{
}
// Left rotation: the right child of `tree` becomes the new subtree root and the
// old root becomes its left child. The pivot's former left subtree moves over
// to be the old root's right subtree. Heights of both moved nodes are
// recomputed (lower node first) and `tree` is updated in place.
// Precondition: tree != NULL && tree->rChild != NULL.
void AVLTree::lRotation(AVLNode* &tree)
{
    AVLNode* demoted = tree;
    AVLNode* pivot = demoted->rChild;

    demoted->rChild = pivot->lChild;
    pivot->lChild = demoted;

    demoted->refreshHeight();   // now a child of pivot, so refresh it first
    pivot->refreshHeight();
    tree = pivot;
}
// Right-left double rotation: rotating the right child right turns the
// right-left shape into right-right, which a left rotation at `tree` then fixes.
void AVLTree::rlRotation(AVLNode* &tree)
{
    AVLNode* &rightSubtree = tree->rChild;
    rRotation(rightSubtree);
    lRotation(tree);
}
// Right rotation: the left child of `tree` becomes the new subtree root and the
// old root becomes its right child. The pivot's former right subtree moves over
// to be the old root's left subtree. Heights of both moved nodes are
// recomputed (lower node first) and `tree` is updated in place.
// Precondition: tree != NULL && tree->lChild != NULL.
void AVLTree::rRotation(AVLNode* &tree)
{
    AVLNode* demoted = tree;
    AVLNode* pivot = demoted->lChild;

    demoted->lChild = pivot->rChild;
    pivot->rChild = demoted;

    demoted->refreshHeight();   // now a child of pivot, so refresh it first
    pivot->refreshHeight();
    tree = pivot;
}
// Left-right double rotation: rotating the left child left turns the
// left-right shape into left-left, which a right rotation at `tree` then fixes.
void AVLTree::lrRotation(AVLNode* &tree)
{
    AVLNode* &leftSubtree = tree->lChild;
    lRotation(leftSubtree);
    rRotation(tree);
}
// Restores the AVL condition (|left height - right height| < 2) at `tree` with
// a single or double rotation. It reads only the stored heights of tree's
// children and grandchildren. Callers (add/remove) refresh tree's height just
// before calling.
// Precondition: tree != NULL.
void AVLTree::rebalance(AVLNode* &tree)
{
    int balance = tree->getLeftHeight() - tree->getRightHeight();

    if (balance >= 2)
    {
        // Left-heavy. A left height >= 2 means lChild cannot be NULL.
        AVLNode* heavy = tree->lChild;
        if (heavy->getLeftHeight() >= heavy->getRightHeight())
        {
            rRotation(tree);    // left-left (or balanced child, after deletion)
        }
        else
        {
            lrRotation(tree);   // left-right
        }
    }
    else if (balance <= -2)
    {
        // Right-heavy. A right height >= 2 means rChild cannot be NULL.
        AVLNode* heavy = tree->rChild;
        if (heavy->getRightHeight() >= heavy->getLeftHeight())
        {
            lRotation(tree);    // right-right (or balanced child, after deletion)
        }
        else
        {
            rlRotation(tree);   // right-left
        }
        // NOTE(review): the rotations already refresh the new root's height;
        // this extra refresh is redundant but harmless.
        tree->refreshHeight();
    }
}
// Recursively inserts `node` into the subtree rooted at `tree`, then refreshes
// heights and rebalances on the way back up. Keys less than the current node's
// go left; equal or greater keys go right. A NULL node is ignored.
// Precondition: `node` is not currently linked into any tree.
void AVLTree::add(AVLNode* &tree, AVLNode* node)
{
    if (node == NULL)
    {
        return;
    }
    if (tree == NULL)
    {
        // Attach as a fresh leaf. Clear the child links first: a node that was
        // removed from a tree may still carry stale pointers. Keeping them would
        // splice foreign subtrees, or even cycles, into this tree.
        node->lChild = NULL;
        node->rChild = NULL;
        node->treeHeight = 1;
        tree = node;
        return;
    }
    if (node->nodeValue < tree->nodeValue)
    {
        add(tree->lChild, node);
    }
    else
    {
        add(tree->rChild, node);
    }
    tree->refreshHeight();
    rebalance(tree);
}
// Recursively removes `node` from the subtree rooted at `tree`. The match is by
// pointer identity, not by value. Heights are refreshed and the subtree is
// rebalanced on the way back up. Does nothing if `node` is NULL or not present.
// On removal the node is fully detached: its child links are cleared and its
// height is reset to 0. It therefore holds no stale pointers into this tree and
// can be re-added or deleted safely.
void AVLTree::remove(AVLNode* &tree, AVLNode* node)
{
    if (tree == NULL || node == NULL)
    {
        return;
    }
    if (tree == node)
    {
        AVLNode* removed = tree;
        if (removed->lChild != NULL && removed->rChild != NULL)
        {
            // Two children: replace the node with its in-order predecessor,
            // the rightmost node of the left subtree.
            AVLNode* pred = removed->lChild;
            while (pred->rChild != NULL)
            {
                pred = pred->rChild;
            }
            remove(removed->lChild, pred);
            pred->lChild = removed->lChild;
            pred->rChild = removed->rChild;
            tree = pred;
        }
        else
        {
            // Zero or one child: splice that child (possibly NULL) into place.
            tree = removed->lChild != NULL ? removed->lChild : removed->rChild;
        }
        // Detach the removed node so it keeps no links into this tree.
        removed->lChild = NULL;
        removed->rChild = NULL;
        removed->treeHeight = 0;
        if (tree == NULL)
        {
            return;
        }
    }
    else
    {
        // Equal keys can end up in either subtree. add() sends ties right, but
        // rotations can move them left. On a tie, both sides are searched.
        if (node->nodeValue <= tree->nodeValue && tree->lChild != NULL)
        {
            remove(tree->lChild, node);
        }
        if (node->nodeValue >= tree->nodeValue && tree->rChild != NULL)
        {
            remove(tree->rChild, node);
        }
    }
    tree->refreshHeight();
    rebalance(tree);
}
// Inserts `node` into the tree. The tree stores the pointer as-is; it neither
// copies the node nor takes ownership of it.
void AVLTree::add(AVLNode* node)
{
    this->add(this->root, node);
}
// Unlinks `node` (matched by pointer identity) from the tree. The node itself
// is not freed; the caller remains responsible for it.
void AVLTree::remove(AVLNode* node)
{
    this->remove(this->root, node);
}
// Returns every node in in-order sequence (non-decreasing nodeValue), using an
// iterative traversal with an explicit stack. An empty tree yields an empty
// vector.
vector<AVLNode*> AVLTree::getSortedList()
{
    vector<AVLNode*> ordered;
    vector<AVLNode*> pending;   // ancestors still waiting to be emitted
    AVLNode* cursor = root;

    while (cursor != NULL || !pending.empty())
    {
        // Descend as far left as possible, remembering the path.
        while (cursor != NULL)
        {
            pending.push_back(cursor);
            cursor = cursor->lChild;
        }
        // The top node's left subtree is done: emit it, then visit its right.
        cursor = pending.back();
        pending.pop_back();
        ordered.push_back(cursor);
        cursor = cursor->rChild;
    }
    return ordered;
}
AVLTreeTest.h源码
#include "stdafx.h"
#include "AVLTree.h"
#include <cstdlib>
#include <iostream>
#include <vector>
using namespace std;
class AVLTreeTest
{
public:
void Print(AVLTree* tree)
{
vector<AVLNode*> sortedList = tree->getSortedList();
for(int i=0; i<sortedList.size(); i++)
{
cout << sortedList[i]->nodeValue << " ";
}
cout <<endl;
}
void DoTest()
{
AVLTree* tree = new AVLTree();
vector<AVLNode*> nodeList;
nodeList.push_back(new AVLNode(43));
nodeList.push_back(new AVLNode(14));
nodeList.push_back(new AVLNode(32));
nodeList.push_back(new AVLNode(52));
nodeList.push_back(new AVLNode(126));
nodeList.push_back(new AVLNode(93));
nodeList.push_back(new AVLNode(131));
nodeList.push_back(new AVLNode(44));
nodeList.push_back(new AVLNode(123));
nodeList.push_back(new AVLNode(9));
for(int i=0;i<nodeList.size();i++)
{
tree->add(nodeList[i]);
Print(tree);
}
for(int i=nodeList.size()-1; i>=0;i--)
{
int r = i*rand()/RAND_MAX;
tree->remove(nodeList[r]);
nodeList.erase(nodeList.begin()+r);
Print(tree);
}
}
};