伸展树,久仰大名。以前就听说天网的那个搜索引擎似乎用了伸展树,今天特地来实现一下。
伸展树介绍
伸展树(Splay Tree)是一种特殊的二叉查找树,又叫自适应查找树。它有一项奇特的技能:可以根据结点被访问的情况进行自调整。当某个结点被访问时,伸展树会通过旋转使该结点成为树根。这样做的好处是,下次再访问该结点时,能够迅速地访问到它。
通常数据领域有个八二法则:20%的数据占据了80%的使用率。所以伸展树是很有用的,它把高频词汇聚集到树根附近,提高多次访问的效率。
伸展树甚至可能由于自调整变得极不平衡,但是它的插入,删除,查找等操作的摊还时间复杂度仍是O(lgN),所以并不逊色于平衡树,在高频词汇的情况下优势更明显。并且伸展树无需在结点内部存放平衡因子,颜色等冗余信息,节省空间。
关于理论方面可以看这几位的博客。代码我参考的是《数据结构与算法分析 C语言描述》,用C++实现了一遍。
代码如下,关键部分我会在下面解释:
头文件:
#ifndef _SPLAY_TREE_H
#define _SPLAY_TREE_H
#include <iostream>
#include <assert.h>
#include <queue>
// Forward declaration so the node can befriend the tree that manipulates it.
template <typename T>
class splay_tree;

/// A single node of a splay tree.
///
/// The node is a passive record: it stores one value and two child links.
/// All fields are private; only splay_tree<T> (declared a friend below)
/// reads or writes them.
template <typename T>
class splay_tree_node {
    friend class splay_tree<T>;

public:
    /// Builds a node.
    /// @param data  value stored in the node (defaults to a value-initialised T)
    /// @param left  left child link (defaults to nullptr)
    /// @param right right child link (defaults to nullptr)
    splay_tree_node(const T& data = T(),
                    splay_tree_node<T>* left = nullptr,
                    splay_tree_node<T>* right = nullptr)
        // Listed in declaration order (left_, right_, data_): members are
        // always initialised in that order regardless of how the list is written.
        : left_(left), right_(right), data_(data)
    {}

private:
    splay_tree_node<T>* left_;   // left child link
    splay_tree_node<T>* right_;  // right child link
    T data_;                     // stored value
};

/// Shorthand used throughout the tree implementation.
template <typename T> using node_type = splay_tree_node<T>;
/// Splay tree keyed by T.
///
/// The tree owns its nodes through raw pointers: the destructor frees every
/// node reachable from root_ and then the nil_ sentinel. A member-wise copy
/// would therefore free the same nodes twice, so copying is disabled.
template <typename T>
class splay_tree {
public:
    splay_tree();
    ~splay_tree();

    // Non-copyable: see the ownership note above.
    splay_tree(const splay_tree&) = delete;
    splay_tree& operator=(const splay_tree&) = delete;

    /// Splays the tree around key and returns the resulting root.
    node_type<T>* splay(const T& key);
    /// Inserts key; returns false when the key is already stored.
    bool insert(const T& key);
    /// Removes key; returns true when a node was removed.
    bool remove(const T& key);
    /// Splays around key; returns the matching node, or a null pointer when absent.
    node_type<T>* find(const T& key);
    /// Writes the stored values to std::cout in ascending order, space separated.
    void inorder_traverse() const;
    /// Writes the stored values to std::cout in breadth-first order, space separated.
    void level_traverse() const;

private:
    // Workers behind the public API; each operates on the subtree rooted at t.
    node_type<T>* splay(node_type<T>*& t, const T& key);
    bool insert(node_type<T>* t, const T& key);
    bool remove(node_type<T>* t, const T& key);
    node_type<T>* find(node_type<T>*& t, const T& key);
    node_type<T>* rotate_left(node_type<T> *t);
    node_type<T>* rotate_right(node_type<T> *t);
    void inorder_traverse(node_type<T>* t) const;
    void level_traverse(node_type<T>* t) const;
    void destroy(node_type<T> *t);

private:
    node_type<T>* root_;  // current root; null until the first insertion
    node_type<T>* nil_;   // sentinel node that stands in for absent children
};
/// Creates an empty tree: no root yet, plus the heap-allocated sentinel node
/// that the other members compare child links against.
template <typename T>
splay_tree<T>::splay_tree()
    : root_(nullptr),
      nil_(new node_type<T>())
{
}
/// Releases every node of the tree and then the sentinel.
template <typename T>
splay_tree<T>::~splay_tree()
{
    // A tree that never received an element still has a null root_. There is
    // nothing to free in that case, so skip the walk instead of handing a
    // null pointer to destroy().
    if (root_ != nullptr)
        destroy(root_);

    // The sentinel is shared by all nodes, so it is freed exactly once, last.
    delete nil_;
    nil_ = nullptr;
    root_ = nullptr;
}
/// Splays the whole tree around key and returns the new root.
/// Returns a null pointer when the tree holds no elements.
template <typename T>
inline node_type<T>* splay_tree<T>::splay(const T& key)
{
    // "Empty" has two encodings: root_ is null before the first insertion and
    // equals the sentinel after the last element has been removed. The worker
    // asserts on the former and would return the sentinel itself for the
    // latter, so both are answered here.
    if (root_ == nullptr || root_ == nil_)
        return nullptr;
    return splay(root_, key);
}

/// Inserts key into the tree; false means the key was already present.
template <typename T>
inline bool splay_tree<T>::insert(const T& key)
{
    return insert(root_, key);
}

/// Prints the whole tree in sorted order.
template <typename T>
inline void splay_tree<T>::inorder_traverse() const
{
    inorder_traverse(root_);
}

/// Prints the whole tree in breadth-first order.
template <typename T>
inline void splay_tree<T>::level_traverse() const
{
    level_traverse(root_);
}

/// Removes key from the tree; true means a node was removed.
template <typename T>
inline bool splay_tree<T>::remove(const T& key)
{
    return remove(root_, key);
}

/// Looks key up in the whole tree (splaying as a side effect).
template <typename T>
inline node_type<T>* splay_tree<T>::find(const T& key)
{
    return find(root_, key);
}
// Top-down splay of the subtree rooted at t around key.
//
// Walks from t towards key. Nodes passed on the way are moved into two side
// trees hanging off the local `header` node: nodes reached while going right
// are appended through L (chain starts at header.right_), nodes reached while
// going left are appended through R (chain starts at header.left_). When the
// walk stops, the side trees are re-attached below the final node.
//
// t is taken by reference and is reassigned during the walk, so on return the
// caller's pointer refers to the new subtree root; the same pointer is also
// returned. The new root holds key if key is present; otherwise it is the
// node at which the walk stopped.
//
// Precondition (asserted): t != NULL. If t is the sentinel nil_ itself, the
// loop body never runs (nil_->data_ has just been set to key) and nil_ is
// returned.
// Requires T to support assignment and the operators !=, < and >.
template <typename T>
node_type<T>* splay_tree<T>::splay(node_type<T>*& t, const T& key) // t must be a reference: it is updated to the new root
{
assert(t != NULL);
node_type<T> header, *L, *R;
header.left_ = header.right_ = nil_;
L = R = &header;
/*
 * Author's warning, kept from the original: do not drop the sentinel
 * assignment below and rely on checking for nil_ before the rotation
 * instead. In that variant nil_ could be rotated into the position of t,
 * which the author reports as leading to a segmentation fault.
 */
nil_->data_ = key; // with the sentinel holding key, "key < child" and "key > child" are false for a nil_ child
while(key != t->data_){
if(key < t->data_){
if(key < t->left_->data_)
t = rotate_right(t); // key is also below the left child: rotate before descending
if(t->left_ == nil_)
break;
// t and its right subtree are larger than key: append t to the R chain
R->left_ = t;
R = R->left_;
t = t->left_;
}
else{
if(key > t->right_->data_)
t = rotate_left(t); // key is also above the right child: rotate before descending
if(t->right_ == nil_)
break;
// t and its left subtree are smaller than key: append t to the L chain
L->right_ = t;
L = L->right_;
t = t->right_;
}
}// loop exits when key == t->data_, or t->left_ == nil_, or t->right_ == nil_
// Reassemble: t's own subtrees close the two chains, then the chains become t's children.
L->right_ = t->left_;
R->left_ = t->right_;
t->left_ = header.right_; // note the crossing: the L chain starts at header.right_
t->right_ = header.left_;
return t;
}
/// Inserts key into the tree whose root is t and makes the new node the root.
///
/// @param t   current root, passed by value (the public wrapper passes root_)
/// @param key value to insert
/// @return true when a node was added, false when key was already stored
///
/// The node to insert is taken from a function-local static spare. The spare
/// is only consumed by a successful insertion, so a failed insertion does not
/// cost an allocation on the next call. Being a function-local static, the
/// spare is shared by every splay_tree<T> of the same T.
template <typename T>
bool splay_tree<T>::insert(node_type<T>* t, const T& key)
{
    static node_type<T>* new_node = nullptr;
    if (new_node == nullptr)
        new_node = new node_type<T>();
    new_node->data_ = key;

    // The tree is empty either before the first insertion (null root) or
    // after the last element was removed (root equals the sentinel). Splaying
    // the sentinel would report key as "already present", so both encodings
    // are handled here.
    if (t == nullptr || t == nil_) {
        new_node->left_ = new_node->right_ = nil_;
    }
    else {
        t = splay(t, key);
        // splay() rearranged the tree through the local copy t. Publish the
        // result right away so that root_ stays valid even when the key turns
        // out to be a duplicate and we return early.
        root_ = t;

        if (t->data_ < key) {
            // t and its left subtree are smaller: they go to the left of the new root.
            new_node->left_ = t;
            new_node->right_ = t->right_;
            t->right_ = nil_;
        }
        else if (t->data_ > key) {
            // t and its right subtree are larger: they go to the right of the new root.
            new_node->right_ = t;
            new_node->left_ = t->left_;
            t->left_ = nil_;
        }
        else {
            // Key already stored: keep the spare node for the next call.
            return false;
        }
    }

    root_ = new_node;
    new_node = nullptr;  // spare consumed; the next call allocates a fresh one
    return true;
}
/// Removes key from the tree whose root is t.
///
/// @param t   current root, passed by value (the public wrapper passes root_)
/// @param key value to remove
/// @return true when a node was removed, false when key was not stored
///
/// After removing the only remaining element, root_ is left equal to the
/// sentinel nil_.
template <typename T>
bool splay_tree<T>::remove(node_type<T>* t, const T& key)
{
    // Nothing to remove from an empty tree (null root, or the sentinel left
    // behind by a previous removal of the last element).
    if (t == nullptr || t == nil_)
        return false;

    t = splay(t, key);
    if (t->data_ == key) {
        node_type<T>* new_root;
        if (t->left_ == nil_) {
            new_root = t->right_;
        }
        else {
            // key is larger than everything in the left subtree, so splaying
            // that subtree around key brings its maximum to the top with an
            // empty right link, which then adopts the right subtree.
            new_root = splay(t->left_, key);
            new_root->right_ = t->right_;
        }
        delete t;
        root_ = new_root;
        return true;
    }

    // Key absent: splay() still rearranged the tree through the local copy t,
    // so the new top must be published or root_ would point into the middle
    // of the tree.
    root_ = t;
    return false;
}
/// Looks key up in the subtree rooted at t, splaying it as a side effect.
///
/// @param t   subtree root, by reference; updated to the root after splaying
/// @param key value to look for
/// @return the node holding key, or a null pointer when key is not stored
template <typename T>
node_type<T>* splay_tree<T>::find(node_type<T>*& t, const T& key)
{
    // An empty tree (null root, or the sentinel) cannot contain key. Without
    // this check a null root trips the assert in splay(), and a sentinel root
    // would be returned as if it were a match.
    if (t == nullptr || t == nil_)
        return nullptr;

    t = splay(t, key);
    return t->data_ == key ? t : nullptr;
}
/// Left rotation: the right child of t becomes the root of this subtree and
/// t becomes its left child. Returns the new subtree root.
template <typename T>
node_type<T>* splay_tree<T>::rotate_left(node_type<T> *t)
{
    node_type<T>* const pivot = t->right_;
    t->right_ = pivot->left_;  // pivot's left subtree is re-parented under t
    pivot->left_ = t;
    return pivot;
}
/// Right rotation: the left child of t becomes the root of this subtree and
/// t becomes its right child. Returns the new subtree root.
template <typename T>
node_type<T>* splay_tree<T>::rotate_right(node_type<T> *t)
{
    node_type<T>* const pivot = t->left_;
    t->left_ = pivot->right_;  // pivot's right subtree is re-parented under t
    pivot->right_ = t;
    return pivot;
}
/// Prints the values of the subtree rooted at t to std::cout in in-order
/// (left, node, right), each followed by a single space.
/// Prints nothing for an empty subtree.
template <typename T>
void splay_tree<T>::inorder_traverse(node_type<T> *t) const
{
    // The sentinel ends every branch; a null pointer is the root of a tree
    // that never received an element.
    if (t == nullptr || t == nil_)
        return;

    inorder_traverse(t->left_);
    std::cout << t->data_ << ' ';
    inorder_traverse(t->right_);
}
/// Prints the values of the subtree rooted at t to std::cout in breadth-first
/// order, each followed by a single space.
/// Prints nothing for an empty subtree.
template <typename T>
void splay_tree<T>::level_traverse(node_type<T>* t) const
{
    // Empty tree: a null root (never populated) or the sentinel (last element
    // removed). Neither may be queued, because the loop below dereferences
    // whatever it pops.
    if (t == nullptr || t == nil_)
        return;

    std::queue<node_type<T>*> pending;
    pending.push(t);
    while (!pending.empty()) {
        node_type<T>* const current = pending.front();
        pending.pop();
        std::cout << current->data_ << ' ';
        // The sentinel marks an absent child and is never queued.
        if (current->left_ != nil_)
            pending.push(current->left_);
        if (current->right_ != nil_)
            pending.push(current->right_);
    }
}
/// Deletes every node of the subtree rooted at t. The sentinel nil_ is not
/// deleted here.
///
/// The walk is iterative and allocates nothing: while the current node has a
/// left child, that child is rotated above it; once it has none, the node is
/// deleted and the walk continues with its right child. This keeps stack
/// usage constant however deep the tree is.
template <typename T>
void splay_tree<T>::destroy(node_type<T> *t)
{
    // A null t is the root of a tree that never received an element.
    while (t != nullptr && t != nil_) {
        node_type<T>* const left = t->left_;
        if (left == nullptr || left == nil_) {
            // No left subtree remains: t can go, continue with what is on its right.
            node_type<T>* const next = t->right_;
            delete t;
            t = next;
        }
        else {
            // Rotate the left child above t so that t is reached again later.
            t->left_ = left->right_;
            left->right_ = t;
            t = left;
        }
    }
}
#endif
注意:
1.本代码在insert()函数中使用了一个局部静态的new_node指针来为待插入的结点申请空间,这是出于效率考虑的。由于伸展树的插入可能失败(键已存在),上一次申请了却没有用到的结点可以留给后面的插入使用,从而降低频繁申请空间的开销。
2.同时使用了nil_作为哨兵结点,并利用哨兵结点来避免各种空指针判断。不过要注意在splay()函数的开头给哨兵结点的data_赋值为待查找的key。
程序测试文件:
#include "splay_tree.h"
#include <iostream>
using namespace std;
int main()
{
splay_tree<int> spt;
spt.insert(10);
spt.insert(50);
spt.insert(40);
spt.insert(30);
spt.insert(20);
spt.insert(60);
spt.splay(30);
spt.inorder_traverse();
cout<<endl;
spt.level_traverse();
cout<<endl;
spt.splay(10);
spt.level_traverse();
cout<<endl;
node_type<int>* ret = spt.find(10);
assert(ret != NULL);
spt.level_traverse();
cout<<endl;
ret = spt.find(60);
assert(ret != NULL);
spt.level_traverse();
cout<<endl;
ret = spt.find(5);
assert(ret == NULL);
spt.level_traverse();
cout<<endl;
ret = spt.find(77);
assert(ret == NULL);
spt.level_traverse();
cout<<endl;
bool success = spt.remove(10);
assert(success);
spt.level_traverse();
cout<<endl;
success = spt.remove(30);
assert(success);
spt.level_traverse();
cout<<endl;
success = spt.remove(50);
assert(success);
spt.level_traverse();
cout<<endl;
success = spt.remove(77);
assert(!success);
spt.level_traverse();
cout<<endl;
return 0;
}