《算法导论》里面的代码真实简洁,看的很舒服,实现也很简单!!!
头文件:bst.h
#ifndef _BST_
#define _BST_
struct Node
{
int data;
Node *parent;
Node *lChild;
Node *rChild;
Node(int d=0,Node* p=nullptr,Node *l=nullptr,Node *r=nullptr):data(d),parent(p),lChild(l),rChild(r){}
~Node();
};
class BST
{
Node *root;
void clear(Node *root);
void transPlant(Node *pa,Node *child);
public:
BST(Node *r=nullptr):root(r){}
Node *getRoot()
{
return root;
}
Node *nextNode(Node *curr);
void insert(int val);
void erase(int val);
void preOrderTraverse(Node *root);
void inOrderTraverse(Node *root);
void postOrderTraverse(Node *root);
Node *search(int val);
~BST();
};
#endif
bst.cpp文件:注意名字一定要是.cpp才可以,不能是.c会报错的!!!
#include"bst.h"
#include<iostream>
#include<stack>
using namespace std;
Node::~Node()
{
std::cout<<data<<"析构"<<std::endl;
}
void BST::insert(int val)
{
if(root==nullptr)
root=new Node(val);
Node *curr=root;
Node *p=root;
bool left;
while(curr!=NULL)
{
left=false;
p=curr;
if(val<curr->data)
{
left=true;
curr=curr->lChild;
}
else if(val>curr->data)
curr=curr->rChild;
else
return ;//表示当前value已经存在
}
if(left==true)
p->lChild=new Node(val,p);
else
p->rChild=new Node(val,p);
}
void BST::clear(Node *root)
{
if(root==NULL)
return ;
clear(root->lChild);
clear(root->rChild);
delete root;
cout<<"析构"<<endl;
}
Node * BST::search(int val)
{
Node *curr=root;
while(curr!=NULL)
{
if(val>curr->data)
curr=curr->rChild;
else if(val<curr->data)
curr=curr->lChild;
else
return curr;
}
return nullptr;
}
BST::~BST()
{
clear(root);
}
Node * BST::nextNode(Node *curr)
{
if(curr==root && root->rChild==nullptr)//保证下面每个节点都有父节点
return nullptr;
if(curr->rChild==nullptr && curr==curr->parent->lChild)
return curr->parent;
if(curr->rChild!=nullptr)
{
Node *temp=curr->rChild;
while(temp->lChild!=nullptr)
temp=temp->lChild;
return temp;
}
if(curr==curr->parent->rChild)
{
while(curr->parent!=nullptr && curr==curr->parent->rChild)
curr=curr->parent;
return curr->parent;
}
}
void BST::transPlant(Node *curr,Node *child)
{
if(curr->parent==nullptr)
root=child;
else if(curr==curr->parent->lChild)
curr->parent->lChild=child;
else
curr->parent->rChild=child;
if(child!=nullptr)
child->parent=curr->parent;
}
void BST::erase(int val)
{
Node *curr=search(val);
if(curr==nullptr)
return ;
if(curr->lChild==nullptr)//左孩子为空,则用右孩子代替
transPlant(curr,curr->rChild);
else if(curr->rChild==nullptr)//右孩子为空,则用左孩子代替
transPlant(curr,curr->lChild);
else//两个孩子都不为空
{
Node *next=nextNode(curr);
if(next->parent!=curr)//如果下一个元素不是curr的右孩子,需要维护curr的右子树
{
transPlant(next,next->rChild);
next->rChild=curr->rChild;
next->rChild->parent=next;
}
//如果下一个元素是curr的右孩子,则用右孩子代替curr节点,同时维护curr的左子树
transPlant(curr,next);
next->lChild=curr->lChild;
next->lChild->parent=next;
}
delete curr;
}
void BST::preOrderTraverse(Node *root)
{
if(root==nullptr)
return ;
stack<Node *> s;
s.push(root);
while(!s.empty())
{
Node *temp=s.top();
cout<<temp->data<<endl;
s.pop();
if(temp->rChild!=nullptr)
s.push(temp->rChild);
if(temp->lChild!=nullptr)
s.push(temp->lChild);
}
}
void BST::inOrderTraverse(Node *root)
{
if(root==nullptr)
return ;
stack<Node *> s;
Node *curr=root;
while(curr!=nullptr)
{
s.push(curr);
curr=curr->lChild;
}
while(!s.empty())
{
curr=s.top();
cout<<curr->data<<endl;
s.pop();
if(curr->rChild!=nullptr)
{
curr=curr->rChild;
while(curr!=nullptr)
{
s.push(curr);
curr=curr->lChild;
}
}
}
}
void BST::postOrderTraverse(Node *root)
{
if(root==nullptr)
return ;
stack<Node *> s;
Node *prev=nullptr;//保存前一时刻访问的节点
Node *curr=root;
while(curr!=nullptr)
{
s.push(curr);
curr=curr->lChild;
}
while(!s.empty())
{
curr=s.top();
if(curr->rChild==prev)
{
cout<<curr->data<<endl;
s.pop();
prev=curr;
}
else
{
curr=curr->rChild;
while(curr!=nullptr)
{
s.push(curr);
curr=curr->lChild;
}
prev=nullptr;
}
}
}
测试的main文件:test.cpp:
#include"bst.h"
#include<iostream>
using namespace std;
int main()
{
BST bst;
//测试insert算法
bst.insert(7);
bst.insert(2);
bst.insert(1);
bst.insert(4);
bst.insert(3);
bst.insert(5);
bst.insert(6);
bst.insert(0);
bst.insert(8);
//测试nextNode算法
cout<<"测试nextNode算法:\n";
Node *curr=bst.getRoot()->lChild->lChild->lChild;
while(curr!=NULL)
{
cout<<curr->data<<endl;
curr=bst.nextNode(curr);
}
cout<<"测试先序遍历算法:\n"<<endl;
bst.preOrderTraverse(bst.getRoot());
cout<<"测试中序遍历算法:\n"<<endl;
bst.inOrderTraverse(bst.getRoot());
cout<<"测试后序遍历算法:\n"<<endl;
bst.postOrderTraverse(bst.getRoot());
//测试erase算法
bst.erase(1);
bst.erase(2);
bst.erase(7);
bst.erase(6);
system("pause");
return 0;
}