实现代码:
#include <vector>
using namespace std;
// Unbalanced binary search tree over values of type T.
// T must support copy construction, copy assignment, ==, < and > (all are
// used by the member definitions). Duplicate values are not stored twice.
template<class T>
class BinarySearchTree{
private:
    // One tree node; nodes are allocated with new and owned by the tree.
    struct BinaryNode{
        T data;
        BinaryNode *left;
        BinaryNode *right;
        BinaryNode(const T &thedata, BinaryNode *lt, BinaryNode *rt)
            :data(thedata),left(lt),right(rt){}
    };
public:
    // Builds a tree rooted at t (an empty tree by default). The tree takes
    // ownership of the nodes reachable from t and frees them on destruction.
    BinarySearchTree(BinaryNode *t=NULL){
        root =t;
    }
    // Deep copy. The implicitly generated copy would share root with rhs,
    // and both destructors would then delete the same nodes.
    BinarySearchTree(const BinarySearchTree &rhs)
        :root(clone(rhs.root)){}
    // Deep-copy assignment; self-assignment is a no-op. The copy of rhs is
    // built before the current nodes are released.
    BinarySearchTree &operator=(const BinarySearchTree &rhs){
        if(this!=&rhs){
            BinaryNode *copy=clone(rhs.root);
            makeEmpty(root);
            root=copy;
        }
        return *this;
    }
    ~BinarySearchTree(){
        makeEmpty(root);
    }
    bool find(const T &x);       // true if x is stored in the tree
    void insert(const T &x);     // adds x unless an equal value is present
    void remove(const T &x);     // removes x if present
    int height();                // nodes on the longest root-to-leaf path
private:
    BinaryNode *root;
    std::vector<BinaryNode*> vNode;   // scratch buffer filled by tranverse()
    void insert(const T &x, BinaryNode* &t);
    void remove(const T &x, BinaryNode* &t);
    bool find(const T &x, BinaryNode* &t);
    void makeEmpty(BinaryNode* &t);
    void tranverse(BinaryNode* &t);
    int height(BinaryNode* &t);
    // Returns a newly allocated copy of the subtree rooted at t (NULL for NULL).
    static BinaryNode *clone(const BinaryNode *t){
        if(NULL==t)
            return NULL;
        return new BinaryNode(t->data,clone(t->left),clone(t->right));
    }
};
// Height of the whole tree, computed recursively by the private helper
// starting from the root node.
template<class T>
int BinarySearchTree<T>::height(){
    return this->height(this->root);
}
// Returns the number of nodes on the longest path from t down to a leaf:
// 0 for an empty subtree, 1 for a single node.
template<class T>
int BinarySearchTree<T>::height(BinaryNode* &t){
    if(NULL==t)
        return 0;
    // Each subtree is measured exactly once. Re-evaluating the taller side
    // inside the result expression made the cost grow exponentially with
    // the height on skewed trees.
    const int leftHeight=height(t->left);
    const int rightHeight=height(t->right);
    return (leftHeight>rightHeight?leftHeight:rightHeight)+1;
}
// In-order walk (left subtree, node, right subtree) of the subtree rooted
// at t. Each visited node is appended to vNode; vNode is not cleared first.
// Uses an explicit stack instead of recursion; the visit order is the same.
template<class T>
void BinarySearchTree<T>::tranverse(BinaryNode* &t){
    std::vector<BinaryNode*> pending;   // ancestors whose node/right side is still due
    BinaryNode *cursor=t;
    while(cursor!=NULL || !pending.empty()){
        // Descend as far left as possible, remembering the path.
        for(; cursor!=NULL; cursor=cursor->left)
            pending.push_back(cursor);
        BinaryNode *visited=pending.back();
        pending.pop_back();
        vNode.push_back(visited);
        cursor=visited->right;
    }
}
// Frees every node in the subtree rooted at t and resets t to NULL, so the
// caller's pointer (e.g. root) does not dangle afterwards.
// Children are deleted before their parent (post-order), so no node is read
// after it has been freed. This needs no scratch buffer; the previous
// iterator declaration also lacked the `typename` required for a dependent
// type and did not compile on conforming compilers.
template<class T>
void BinarySearchTree<T>::makeEmpty(BinaryNode* &t){
    if(NULL==t)
        return;
    makeEmpty(t->left);
    makeEmpty(t->right);
    delete t;
    t=NULL;
}
// Public lookup: reports whether a value equal to x is in the tree,
// delegating to the private helper starting at the root.
template<class T>
bool BinarySearchTree<T>::find(const T &x){
    return this->find(x,this->root);
}
// Searches the subtree rooted at t for x; t itself is never modified.
// A node matches when x==data. Otherwise the walk goes left when x<data and
// right in every other case. Returns false once the walk leaves the tree.
template<class T>
bool BinarySearchTree<T>::find(const T &x, BinaryNode* &t){
    BinaryNode *cur=t;
    while(cur!=NULL){
        if(x==cur->data)
            return true;
        // Unconditional else-branch: if x is neither ==, < nor > the node's
        // data (e.g. NaN), the search still ends with a defined result
        // instead of falling off the end of a non-void function.
        cur=(x<cur->data)?cur->left:cur->right;
    }
    return false;
}
// Public insertion: adds x to the tree by delegating to the private helper,
// which may replace root when the tree is empty.
template<class T>
void BinarySearchTree<T>::insert(const T &x){
    this->insert(x,this->root);
}
// Inserts x into the subtree rooted at t by walking down the child links.
// A value that is neither less than nor greater than an existing key is
// treated as a duplicate and the tree is left unchanged.
template<class T>
void BinarySearchTree<T>::insert(const T &x, BinaryNode* &t){
    BinaryNode **link=&t;   // the child pointer that will receive the new node
    while(*link!=NULL){
        BinaryNode *node=*link;
        if(x<node->data)
            link=&node->left;
        else if(x>node->data)
            link=&node->right;
        else
            return;         // duplicate
    }
    *link=new BinaryNode(x,NULL,NULL);
}
// Public removal: deletes x from the tree if present, delegating to the
// private helper starting at the root (which may replace root).
template<class T>
void BinarySearchTree<T>::remove(const T &x){
    this->remove(x,this->root);
}
// Removes x from the subtree rooted at t; does nothing if x is absent.
// A node with at most one child is spliced out and deleted. A node with two
// children takes the smallest value of its right subtree, and that value is
// then removed from the right subtree.
template<class T>
void BinarySearchTree<T>::remove(const T &x, BinaryNode* &t){
    if(t==NULL)
        return;
    if(x<t->data){
        remove(x,t->left);
        return;
    }
    if(x>t->data){
        remove(x,t->right);
        return;
    }
    // Match found at t.
    if(t->left==NULL || t->right==NULL){
        BinaryNode *doomed=t;
        t=(t->left==NULL)?t->right:t->left;   // NULL when t was a leaf
        delete doomed;
        return;
    }
    BinaryNode *successor=t->right;
    for(; successor->left!=NULL; successor=successor->left){}
    t->data=successor->data;
    remove(t->data,t->right);
}
测试代码:
#include <iostream>
#include <vector>
#include "main.h"
using namespace std;
// Smoke test: builds a tree from a fixed array, then prints heights and
// lookup results (bools print as 1/0) around two removals.
int main(){
    int a[]={10,8,6,21,87,56,4,0,11,3,22,7,5,34,1,2,9};
    // Element count is derived from the array itself, so the loop bound
    // cannot drift out of sync with the initializer list.
    const int count=static_cast<int>(sizeof(a)/sizeof(a[0]));
    BinarySearchTree<int> tree;
    for(int i=0;i<count;i++)
        tree.insert(a[i]);
    cout << endl;                       // leading blank line
    cout << tree.height() << endl;      // expected 8: path 10-8-6-4-0-3-1-2
    cout << tree.find(2) << endl;       // expected 1
    tree.remove(2);                     // 2 is a leaf (right child of 1)
    cout << tree.find(2) << endl;       // expected 0
    cout << tree.height() << endl;      // expected 7
    tree.remove(11);                    // 11 is a leaf (left child of 21)
    cout << tree.height() << endl;      // expected 7
    cout << tree.find(100) << endl;     // expected 0
    cout << tree.find(56) << endl;      // expected 1
    return 0;
}