二叉搜索树与双向链表的相同点是:
二叉搜索树有左右指针域,而双向链表也有前驱与后继指针,这样只要按照中序遍历的方法,将二叉树遍历一遍,改变原来左右指针域的指向(左指针指向前驱,右指针指向后继),就能变成双向链表。
首先为二叉搜索树实现迭代器,迭代器的本质就是对结点指针的封装,通过对迭代器执行++或--操作,依次获得中序序列中的下一个或上一个节点。
// A single node of the binary search tree: a key/value pair plus links to
// the left child, the right child and the parent.
template<class K, class V>
struct BSTNode
{
    // Structural links; a freshly constructed node is not linked to anything.
    BSTNode* _pLeft;
    BSTNode* _pRight;
    BSTNode* _pParent;

    // Payload: the ordering key and the value stored with it.
    K _key;
    V _value;

    // Builds a detached node. Both arguments default to value-initialized
    // objects so a node can also be created without any payload.
    BSTNode(const K& key = K(), const V& value = V())
        : _pLeft(NULL), _pRight(NULL), _pParent(NULL), _key(key), _value(value)
    {
    }
};
class BSTIterator
{
typedef BSTNode<K, V> Node;
typedef BSTIterator<K, V, Ref, Ptr> Self;
public:
BSTIterator()
:_pNode(NULL)
{}
BSTIterator(Node* pNode)
:_pNode(_pNode)
{}
BSTIterator(const Self& s)
:_pNode(s._pNode)
{}
Self& operator++()
{
Increment();
return *this;
}
Self operator++(int)
{
Self pCur(*this);
Increment();
return pCur;
}
Self& operator--()
{
Decrement();
return *this;
}
Self operator--(int)
{
Self pCur(*this);
Decrement();
return pCur;
}
Ref operator*()
{
return _pNode->_key;
}
Ptr operator->()
{
return &(operator*());
}
bool operator!=(const Self& s)
{
return _pNode != s._pNode;
}
bool operator==(const Self& s)
{
return _pNode == s._pNode;
}
protected:
// 取当前结点的下一个结点
void Increment()
{
//如果右孩子存在,取右孩子中的最小节点
if(_pNode->_pRight)
{
_pNode = _pNode->_pRight;
while(_pNode->_pRight)
_pNode = _pNode->_pLeft;
}
else
{
Node* pParent = _pNode->_pParent;
while(pParent->_pRight == _pNode)
{
_pNode = pParent;
pParent = pParent->_pParent;
}
if(_pNode->_pRight != pParent)
_pNode = pParent;
}
}
// --取前一个小的结点,在left子树中
void Decrement()
{
if(_pNode->_pLeft)
{
_pNode = _pNode->_pLeft;
while(_pNode->_pRight)
_pNode = _pNode->_pRight;
}
else
{
Node* pParent = _pNode->_pParent;
while(pParent->_pLeft == _pNode)
{
_pNode = pParent;
pParent = pParent->_pParent;
}
_pNode = pParent;
}
}
protected:
Node* _pNode;
};
在该类中需要重载许多运算符。
因为迭代器的begin()函数与end()函数是前闭后开的,表明end()的位置是取不到的,所以为了方便起见,需要定义一个头节点来表明这个位置。
因此它的插入与删除函数与一般的二叉搜索树有些不同:每次修改之后都需要维护头节点的指针,使其分别指向根节点、最小节点和最大节点。
template <class K, class V>
class BSTree
{
typedef BSTNode<K, V> Node;
public:
typedef BSTIterator<K, V, K&, K*> Iterator;
BSTree()
{
_pHead = new Node();
_pHead->_pLeft = _pHead;
_pHead->_pRight = _pHead;
_pHead->_pParent = NULL;
}
BSTree(const BSTree& bst)
{
Node* pRoot = GetRoot();
pRoot = _CopyBSTree(bst.pRoot);
}
BSTree<K, V>& operator=(const BSTree<K, V>& bst)
{
Node* pRoot = GetRoot();
if(&bst != this)
{
_Destroy(pRoot);
pRoot = _CopyBSTree(bst.pRoot);
}
return *this;
}
Iterator Begin()
{
return _pHead->_pLeft;
}
Iterator End()
{
return _pHead;
}
bool Insert(const K& key, const V& value)
{
Node*& pRoot = _pHead->_pParent;
if(pRoot == NULL)
{
pRoot = new Node(key, value);
pRoot->_pParent = _pHead;
}
else
{
Node* pCur = pRoot;
Node* pParent = NULL;
while(pCur)
{
if(key < pCur->_key)
{
pParent = pCur;
pCur = pCur->_pLeft;
}
else if(key > pCur->_key)
{
pParent = pCur;
pCur = pCur->_pRight;
}
else
return false;
}
pCur = new Node(key, value);
if(pParent->_key > key)
pParent->_pLeft = pCur;
else
pParent->_pRight = pCur;
pCur->_pParent = pParent;
}
_pHead->_pLeft = GetMinKey();
_pHead->_pRight = GetMaxKey();
return true;
}
Node* Find(const K& key)
{
Node* pRoot = GetRoot();
if(pRoot)
{
Node* pCur = _pHead->_pParent;
while(pCur)
{
if(pCur->_key == key)
return pCur;
else if(pCur->_key > key)
pCur = pCur->_pLeft;
else
pCur = pCur->_pRight;
}
}
return NULL;
}
bool Remove(const K& key)
{
Node*& pRoot = GetRoot();
if(pRoot == NULL)
return false;
if(pRoot->_pLeft == NULL && pRoot->_pRight == NULL && pRoot->_key == key)
{
delete pRoot;
_pHead->_pParent = NULL;
}
else
{
Node* pCur = pRoot;
Node* pParent = NULL;
//先找到要删除的节点
while(pCur)
{
if(pCur->_key > key)
{
pParent = pCur;
pCur = pCur->_pLeft;
}
else if(pCur->_key < key)
{
pParent = pCur;
pCur = pCur->_pRight;
}
else
break;
}
if(pCur == NULL)
return false;
else
{
//左孩子为空,右孩子可能为空
if(pCur->_pLeft == NULL)
{
//当前节点不是根节点
if(pCur != pRoot)
{
//判断当前节点是其双亲的左孩子还是右孩子
if(pCur == pParent->_pLeft)
pParent->_pLeft = pCur->_pRight;
else
pParent->_pRight = pCur->_pRight;
}
else
pRoot = pCur->_pRight;
}
//右孩子为空,左孩子不为空
else if(pCur->_pRight == NULL)
{
if(pCur != pRoot)
{
if(pCur == pParent->_pLeft)
pParent->_pLeft = pCur->_pLeft;
else
pParent->_pRight = pCur->_pLeft;
}
else
pRoot = pCur->_pLeft;
}
//左右孩子都不为空
else
{
//先找到右子树的最左结点
Node* MinNodeInRightTree = pCur->_pRight;
pParent = pCur;
while(MinNodeInRightTree->_pLeft)
{
pParent = MinNodeInRightTree;
MinNodeInRightTree = MinNodeInRightTree->_pLeft;
}
//将最左结点与当前节点值交换
pCur->_key = MinNodeInRightTree->_key;
pCur->_value = MinNodeInRightTree->_value;
//问题转化成左孩子为空
if(MinNodeInRightTree == pParent->_pLeft)
pParent->_pLeft = MinNodeInRightTree->_pRight;
else
pParent->_pRight = MinNodeInRightTree->_pRight;
pCur = MinNodeInRightTree;
}
delete pCur;
}
}
_pHead->_pLeft = GetMinKey();
_pHead->_pRight = GetMaxKey();
return true;
}
Node* GetMaxKey()
{
return _GetMaxKey(GetRoot());
}
Node* GetMinKey()
{
return _GetMinKey(GetRoot());
}
Node*& GetRoot()
{
return _pHead->_pParent;
}
void InOrder()
{
cout<<"InOrder: ";
_InOrder(GetRoot());
cout<<endl;
}
Node* ToList()
{
//链表的头结点就是最左结点
Node* pHead = GetRoot();
Node* pPre = NULL;
while(pHead->_pLeft)
pHead = pHead->_pLeft;
_ToList(GetRoot(), pPre);
return pHead;
}
private:
void _ToList(Node* pRoot, Node*& pPre)
{
if(pRoot)
{
_ToList(pRoot->_pLeft, pPre);
pRoot->_pLeft = pPre;
if(pPre)
pPre->_pRight = pRoot;
pPre = pRoot;
_ToList(pRoot->_pRight, pPre);
}
}
void _Destroy(Node*& pRoot)
{
if(pRoot)
{
_Destroy(pRoot->_pLeft);
_Destroy(pRoot->_pRight);
delete pRoot;
pRoot = NULL;
}
}
Node* _CopyBSTree(Node* pRoot)
{
Node* pCur = NULL;
if(pRoot)
{
pCur = new Node(pRoot->_key, pRoot->_value);
pCur->_pLeft = _CopyBSTree(pRoot->_pLeft);
pCur->_pRight = _CopyBSTree(pRoot->_pRight);
}
return pCur;
}
void _InOrder(Node* pRoot)
{
if(pRoot)
{
_InOrder(pRoot->_pLeft);
cout<<pRoot->_key<<" ";
_InOrder(pRoot->_pRight);
}
}
Node* _GetMaxKey(Node* pRoot)
{
while(pRoot->_pRight)
pRoot = pRoot->_pRight;
return pRoot;
}
Node* _GetMinKey(Node* pRoot)
{
while(pRoot->_pLeft)
pRoot = pRoot->_pLeft;
return pRoot;
}
private:
Node* _pHead;
};
调用方式:
#include "BSTIterator.h"
int main()
{
int arr[] = {5,3,4,1,7,8,2,6,0,9};
BSTree<int, int> bt;
for(int i=0; i<sizeof(arr)/sizeof(arr[0]); ++i)
{
bt.Insert(arr[i], i);
}
bt.InOrder();
BSTNode<int, int>* pos = bt.ToList();
while(pos)
{
cout<<pos->_key<<" ";
pos = pos->_pRight;
}
/*BSTree<int, int> ::Iterator it = bt.Begin();
while(it != bt.End())
{
cout<<*it<<" ";
it++;
}
cout<<endl;*/
system("pause");
return 0;
}