自顶向下伸展树的详细介绍见 Weiss 著《数据结构与算法分析：Java 语言描述》（第三版），这里只给出实现。
以下代码同时包含了伸展树在自底向上和自顶向下两种伸展方式下插入、删除和搜索操作的完整实现。
C++代码:
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <random>
#include <stack>
#include <vector>
using namespace std;
// Outcome of checking one subtree for the binary-search-tree property.
// The min/max fields are only consulted by SplayTree::isBST when isBST is true.
struct JudgeResult
{
    bool isBST{true};          // whether the subtree satisfies the BST ordering
    int max_value_in_BST{0};   // largest key found in the subtree
    int min_value_in_BST{0};   // smallest key found in the subtree
};
// A single splay-tree node: the stored key plus raw, non-owning links to its two children
// (ownership is managed by SplayTree).
template <typename T>
struct SplayTreeNode
{
    T data;                                 // key held by this node
    SplayTreeNode* left_child{nullptr};     // left subtree, or nullptr
    SplayTreeNode* right_child{nullptr};    // right subtree, or nullptr
    SplayTreeNode(const T& d) : data(d) {}
};
// Left-right double rotation of the subtree rooted at ptr.
// The grandchild ptr->left_child->right_child becomes the new subtree root, with the old left
// child as its left child and ptr as its right child. The caller must relink the new root
// into ptr's former parent.
template <typename T>
void RotateLR(SplayTreeNode<T>* ptr)
{
    SplayTreeNode<T>* const child = ptr->left_child;
    SplayTreeNode<T>* const grandchild = child->right_child;
    // The grandchild's two subtrees are handed to the nodes that end up on either side of it.
    child->right_child = grandchild->left_child;
    ptr->left_child = grandchild->right_child;
    grandchild->left_child = child;
    grandchild->right_child = ptr;
}
// Right-left double rotation of the subtree rooted at ptr.
// The grandchild ptr->right_child->left_child becomes the new subtree root, with ptr as its
// left child and the old right child as its right child. The caller must relink the new root
// into ptr's former parent.
template <typename T>
void RotateRL(SplayTreeNode<T>* ptr)
{
    SplayTreeNode<T>* const child = ptr->right_child;
    SplayTreeNode<T>* const grandchild = child->left_child;
    // The grandchild's two subtrees are handed to the nodes that end up on either side of it.
    child->left_child = grandchild->right_child;
    ptr->right_child = grandchild->left_child;
    grandchild->right_child = child;
    grandchild->left_child = ptr;
}
// Single right rotation of the subtree rooted at ptr: ptr's left child becomes the subtree root
// and ptr becomes its right child. The caller must relink the new root into ptr's former parent.
template <typename T>
void RotateR(SplayTreeNode<T>* ptr)
{
    SplayTreeNode<T>* const new_top = ptr->left_child;
    ptr->left_child = new_top->right_child;   // inner subtree moves across
    new_top->right_child = ptr;
}
// Single left rotation of the subtree rooted at ptr: ptr's right child becomes the subtree root
// and ptr becomes its left child. The caller must relink the new root into ptr's former parent.
template <typename T>
void RotateL(SplayTreeNode<T>* ptr)
{
    SplayTreeNode<T>* const new_top = ptr->right_child;
    ptr->right_child = new_top->left_child;   // inner subtree moves across
    new_top->left_child = ptr;
}
template <typename T>
class SplayTree
{
public:
enum OperateType { INSERT, DELETE, SEARCH };
SplayTreeNode<T>* search(const T& key); //伸展树中搜索关键码
T getRootValue() { return root->data; }
bool isEmpty() { return root == nullptr; }
bool insert(const T& key); //伸展树中插入关键码
bool remove(const T& key); //伸展树中移除关键码
bool curTreeIsBST() { return isBST(root).isBST; }
void outputInorderSeq() { if (root == nullptr) { cout << "NULL"; return; } inorderTraversal(root); } //打印伸展树中序序列
SplayTree(bool u) :use_spread_topdown(u) {}
~SplayTree() { destoryTree(root); }
private:
JudgeResult isBST(SplayTreeNode<T>* root);
void inorderTraversal(SplayTreeNode<T>* root)
{
if (root != nullptr)
{
inorderTraversal(root->left_child);
cout << root->data << " ";
inorderTraversal(root->right_child);
}
}
void destoryTree(SplayTreeNode<T>* root)
{
if (root != nullptr)
{
destoryTree(root->left_child);
destoryTree(root->right_child);
delete root;
}
}
void adjustUntilRoot(SplayTreeNode<T>* cur, vector<SplayTreeNode<T>*>& work_stack, size_t top); //从cur节点开始反复旋转,直到将cur调整至根节点
bool spreadTopDownAndOP(OperateType operate_type, const T& key); //伸展树的自顶向下展开,operate_type为要执行的操作类型
void splay(vector<SplayTreeNode<T>*>& work_stack, const T& key);
SplayTreeNode<T>* root = nullptr; //伸展树根节点
bool use_spread_topdown = false; //是否使用自顶向下展开处理伸展树操作
};
// Checks whether the subtree rooted at `root` satisfies the strict BST ordering
// (all keys in a left subtree < node key < all keys in its right subtree).
// Returns the verdict plus the subtree's minimum / maximum key. Callers only read the
// min/max fields when isBST is true, so the search stops as soon as a violation is found.
// An empty subtree is a valid BST.
// NOTE(review): JudgeResult stores min/max as int, so keys are converted to int here; the
// check is only reliable for T that converts to int without loss (e.g. T == int).
template <typename T>
JudgeResult SplayTree<T>::isBST(SplayTreeNode<T>* root)
{
    JudgeResult result;
    if (root == nullptr)
    {
        return result;
    }
    // A missing child means this node itself is the extreme on that side.
    result.min_value_in_BST = root->data;
    result.max_value_in_BST = root->data;

    if (root->left_child != nullptr)
    {
        const JudgeResult left_result = isBST(root->left_child);
        if (left_result.isBST != true || !(left_result.max_value_in_BST < root->data))
        {
            result.isBST = false;
            return result;   // no need to inspect the right subtree
        }
        result.min_value_in_BST = left_result.min_value_in_BST;
    }
    if (root->right_child != nullptr)
    {
        const JudgeResult right_result = isBST(root->right_child);
        if (right_result.isBST != true || !(right_result.min_value_in_BST > root->data))
        {
            result.isBST = false;
            return result;
        }
        result.max_value_in_BST = right_result.max_value_in_BST;
    }
    return result;
}
// Searches for key.
// Top-down mode: delegates to spreadTopDownAndOP and returns the root if key was found.
// Bottom-up mode: walks down recording the access path. On a hit the node is splayed to the
// root and returned; on a miss splay() restructures along the path and nullptr is returned.
template <typename T>
SplayTreeNode<T>* SplayTree<T>::search(const T& key)
{
    if (use_spread_topdown)
    {
        return spreadTopDownAndOP(OperateType::SEARCH, key) ? root : nullptr;
    }
    if (root == nullptr)
    {
        return nullptr;   // nothing to find, nothing to splay
    }
    vector<SplayTreeNode<T>*> ancestors;   // path from the root to the current node's parent
    for (SplayTreeNode<T>* node = root; node != nullptr;
         node = (key < node->data) ? node->left_child : node->right_child)
    {
        if (node->data == key)
        {
            adjustUntilRoot(node, ancestors, ancestors.size());
            return root;
        }
        ancestors.push_back(node);
    }
    splay(ancestors, key);
    return nullptr;
}
// Appends change_ptr to the left tree of a top-down splay.
// The first node becomes the left tree itself; later nodes hang off the current joint
// point's right link. Either way change_ptr becomes the new joint point (rightmost node).
template <typename T>
void updateLeft(SplayTreeNode<T>*& left_tree, SplayTreeNode<T>*& left_joint_point, SplayTreeNode<T>* change_ptr)
{
    SplayTreeNode<T>*& slot = (left_tree == nullptr) ? left_tree : left_joint_point->right_child;
    slot = change_ptr;
    left_joint_point = change_ptr;
}
// Appends change_ptr to the right tree of a top-down splay.
// The first node becomes the right tree itself; later nodes hang off the current joint
// point's left link. Either way change_ptr becomes the new joint point (leftmost node).
template <typename T>
void updateRight(SplayTreeNode<T>*& right_tree, SplayTreeNode<T>*& right_joint_point, SplayTreeNode<T>* change_ptr)
{
    SplayTreeNode<T>*& slot = (right_tree == nullptr) ? right_tree : right_joint_point->left_child;
    slot = change_ptr;
    right_joint_point = change_ptr;
}
// One step of the top-down splay, called by SplayTree::spreadTopDownAndOP.
// work_stack: the nodes visited during the last descent (at most three). The bottom entry is the
//   current middle-tree root; the top entry is the deepest node visited.
// cur: the node reached after that descent (may be nullptr). If cur is nullptr, or all three
//   stacked nodes were consumed, the deepest stacked node takes cur's place instead.
// With one ancestor left above cur, a zig is done. With two, a zig-zig (rotate, then link) or a
// zig-zag (link both) is done. Split-off nodes with keys greater than cur are appended to the
// right tree (updateRight); smaller ones go to the left tree (updateLeft).
// On return cur is the root of the remaining middle tree.
template <typename T>
void rotate(SplayTreeNode<T>*& cur, stack<SplayTreeNode<T>*>& work_stack, SplayTreeNode<T>*& left_tree, SplayTreeNode<T>*& right_tree, SplayTreeNode<T>*& left_joint_point, SplayTreeNode<T>*& right_joint_point) // rotation step of the top-down splay
{
// Descent hit nullptr, or went three levels deep: work toward the deepest stacked node instead.
if (cur == nullptr || work_stack.size() == 3)
{
cur = work_stack.top();
work_stack.pop();
}
// cur is already the middle-tree root; nothing to rotate.
if (work_stack.empty())
return;
SplayTreeNode<T>* p = work_stack.top(); // cur's parent
work_stack.pop();
if (work_stack.empty() == true)
{
// Zig: p is the middle-tree root. Detach cur and move p to the side opposite cur.
if (p->left_child == cur)
{
p->left_child = nullptr;
updateRight(right_tree, right_joint_point, p);
}
else
{
p->right_child = nullptr;
updateLeft(left_tree, left_joint_point, p);
}
}
else
{
SplayTreeNode<T>* q = work_stack.top(); // cur's grandparent (the middle-tree root)
work_stack.pop();
if (p->left_child == cur)
{
p->left_child = nullptr;
if (q->left_child == p)
{
// Zig-zig (left-left): rotate right at q, then p (now holding q as its right child) joins the right tree.
RotateR(q);
updateRight(right_tree, right_joint_point, p);
}
else
{
// Zig-zag (right-left): q joins the left tree, p joins the right tree.
q->right_child = nullptr;
updateLeft(left_tree, left_joint_point, q);
updateRight(right_tree, right_joint_point, p);
}
}
else
{
p->right_child = nullptr;
if (q->left_child == p)
{
// Zig-zag (left-right): p joins the left tree, q joins the right tree.
q->left_child = nullptr;
updateLeft(left_tree, left_joint_point, p);
updateRight(right_tree, right_joint_point, q);
}
else
{
// Zig-zig (right-right): rotate left at q, then p (now holding q as its left child) joins the left tree.
RotateL(q);
updateLeft(left_tree, left_joint_point, p);
}
}
}
}
// Final reassembly of a top-down splay: cur (the middle-tree root) becomes the overall root.
// The middle tree's left subtree is hung under the left tree's joint point (its rightmost node),
// and the whole left tree becomes cur's left subtree. The right side is the mirror image.
template <typename T>
void Union(SplayTreeNode<T>* cur, SplayTreeNode<T>* left_tree, SplayTreeNode<T>* right_tree, SplayTreeNode<T>* left_joint_point, SplayTreeNode<T>* right_joint_point)
{
    const bool has_left_part = (left_tree != nullptr);
    const bool has_right_part = (right_tree != nullptr);
    if (has_left_part)
    {
        left_joint_point->right_child = cur->left_child;
        cur->left_child = left_tree;
    }
    if (has_right_part)
    {
        right_joint_point->left_child = cur->right_child;
        cur->right_child = right_tree;
    }
}
// Removes the in-order successor of cur (the minimum of cur's right subtree) and copies its key
// into cur. Precondition: cur->right_child != nullptr.
template <typename T>
void rightReplace(SplayTreeNode<T>* cur)
{
    SplayTreeNode<T>* owner = cur;                  // node whose link points at `successor`
    SplayTreeNode<T>* successor = cur->right_child;
    while (successor->left_child != nullptr)
    {
        owner = successor;
        successor = successor->left_child;
    }
    // Splice the successor out, keeping its right subtree.
    if (owner == cur)
    {
        cur->right_child = successor->right_child;
    }
    else
    {
        owner->left_child = successor->right_child;
    }
    cur->data = successor->data;
    delete successor;
}
// Removes the in-order predecessor of cur (the maximum of cur's left subtree) and copies its key
// into cur. Precondition: cur->left_child != nullptr.
template <typename T>
void leftReplace(SplayTreeNode<T>* cur)
{
    SplayTreeNode<T>* owner = cur;                    // node whose link points at `predecessor`
    SplayTreeNode<T>* predecessor = cur->left_child;
    while (predecessor->right_child != nullptr)
    {
        owner = predecessor;
        predecessor = predecessor->right_child;
    }
    // Splice the predecessor out, keeping its left subtree.
    if (owner == cur)
    {
        cur->left_child = predecessor->left_child;
    }
    else
    {
        owner->right_child = predecessor->left_child;
    }
    cur->data = predecessor->data;
    delete predecessor;
}
// Deletes the key held by root node cur.
// If cur has a right subtree, the successor's key replaces it; otherwise, if it has a left subtree,
// the predecessor's key does. Returns cur, or nullptr when cur was a lone node (now freed).
template <typename T>
SplayTreeNode<T>* removeAtRoot(SplayTreeNode<T>* cur)
{
    const bool is_leaf = (cur->right_child == nullptr && cur->left_child == nullptr);
    if (is_leaf)
    {
        delete cur;
        return nullptr;
    }
    if (cur->right_child != nullptr)
    {
        rightReplace(cur);
    }
    else
    {
        leftReplace(cur);
    }
    return cur;
}
// Top-down splay for key, combined with the operation operate_type.
// Returns:
//   INSERT: true if a new node was created, false if key was already present.
//   DELETE: true if key was found and removed, false otherwise.
//   SEARCH: true if key was found (root then holds key).
// On a miss (SEARCH/DELETE), the last node visited is splayed to the root.
template <typename T>
bool SplayTree<T>::spreadTopDownAndOP(SplayTree<T>::OperateType operate_type, const T& key)
{
SplayTreeNode<T>* cur = root;
// Empty tree: only INSERT can succeed.
if (cur == nullptr)
{
if (operate_type == OperateType::INSERT)
{
root = new SplayTreeNode<T>(key);
return true;
}
else
{
return false;
}
}
// Left tree collects nodes smaller than the middle tree, right tree collects larger ones.
// The joint points are the nodes where the next split-off node gets attached
// (rightmost node of the left tree, leftmost node of the right tree).
SplayTreeNode<T>* left_tree = nullptr;
SplayTreeNode<T>* right_tree = nullptr;
SplayTreeNode<T>* left_joint_point = nullptr;
SplayTreeNode<T>* right_joint_point = nullptr;
bool has_inserted = false;
while (true)
{
// Descend at most three levels from the middle-tree root, stopping early at nullptr or a key match.
// i records where the descent stopped: i == 1 means the middle-tree root itself matched.
stack<SplayTreeNode<T>*> temp_stack;
int i = 0;
for (i = 1; i <= 3; ++i)
{
if (cur == nullptr || cur->data == key)
{
break;
}
temp_stack.push(cur);
if (key < cur->data)
{
cur = cur->left_child;
}
else
{
cur = cur->right_child;
}
}
if (cur == nullptr)
{
if (operate_type == OperateType::INSERT)
{
// Attach the new key as a leaf under the last visited node; splaying continues and will reach it.
has_inserted = true;
if (key < temp_stack.top()->data)
{
cur = temp_stack.top()->left_child = new SplayTreeNode<T>(key);
}
else
{
cur = temp_stack.top()->right_child = new SplayTreeNode<T>(key);
}
}
else
{
// Miss: splay the last visited node, reassemble the three trees and report failure.
rotate(cur, temp_stack, left_tree, right_tree, left_joint_point, right_joint_point);
root = cur;
Union(root, left_tree, right_tree, left_joint_point, right_joint_point);
return false;
}
}
else
{
if (i == 1)
{
// key is at the middle-tree root: reassemble, then apply the operation.
Union(cur, left_tree, right_tree, left_joint_point, right_joint_point);
if (operate_type == OperateType::DELETE)
{
root = removeAtRoot(cur);
return true;
}
else
{
root = cur;
if (operate_type == OperateType::INSERT)
{
// Found without having inserted means the key was a duplicate.
if (has_inserted)
return true;
return false;
}
return true;
}
}
}
// Perform one zig / zig-zig / zig-zag step and continue from the new middle-tree root.
rotate(cur, temp_stack, left_tree, right_tree, left_joint_point, right_joint_point);
}
}
// Bottom-up splay: rotates cur up to the root of the whole tree.
// work_stack[0 .. top-1] must be the access path from the tree root down to cur's parent
// (entries at index >= top are ignored). Each iteration does a zig-zig or zig-zag using the
// parent and grandparent, then relinks cur under the great-grandparent. A single rotation (zig)
// finishes when only the parent remains. Finally `root` is set to cur.
template <typename T>
void SplayTree<T>::adjustUntilRoot(SplayTreeNode<T>* cur, vector<SplayTreeNode<T>*>& work_stack, size_t top)
{
    size_t remaining = top;   // ancestors of cur still above it
    while (remaining != 0)
    {
        SplayTreeNode<T>* const parent = work_stack[--remaining];
        const bool cur_is_left = (parent->left_child == cur);

        if (remaining == 0)
        {
            // parent is the tree root: one rotation completes the splay.
            if (cur_is_left)
            {
                RotateR(parent);
            }
            else
            {
                RotateL(parent);
            }
            continue;
        }

        SplayTreeNode<T>* const grand = work_stack[--remaining];
        const bool parent_is_left = (grand->left_child == parent);

        if (cur_is_left && parent_is_left)          // zig-zig, left-left
        {
            RotateR(grand);
            RotateR(parent);
        }
        else if (!cur_is_left && !parent_is_left)   // zig-zig, right-right
        {
            RotateL(grand);
            RotateL(parent);
        }
        else if (cur_is_left)                       // zig-zag: grand -> right -> left
        {
            RotateRL(grand);
        }
        else                                        // zig-zag: grand -> left -> right
        {
            RotateLR(grand);
        }

        // Replace grand by cur in the great-grandparent's child link.
        if (remaining != 0)
        {
            SplayTreeNode<T>* const above = work_stack[remaining - 1];
            if (above->left_child == grand)
            {
                above->left_child = cur;
            }
            else
            {
                above->right_child = cur;
            }
        }
    }
    root = cur;
}
// Inserts key.
// Top-down mode: delegates to spreadTopDownAndOP.
// Bottom-up mode: descends recording the path. A duplicate key is splayed to the root and false is
// returned. Otherwise a new leaf is attached and splayed to the root, and true is returned.
template <typename T>
bool SplayTree<T>::insert(const T& key)
{
    if (use_spread_topdown)
    {
        return spreadTopDownAndOP(OperateType::INSERT, key);
    }
    if (root == nullptr)
    {
        root = new SplayTreeNode<T>(key);
        return true;
    }
    vector<SplayTreeNode<T>*> path;   // root .. parent of the current node
    SplayTreeNode<T>* node = root;
    do
    {
        if (node->data == key)
        {
            adjustUntilRoot(node, path, path.size());
            return false;
        }
        path.push_back(node);
        node = (key < node->data) ? node->left_child : node->right_child;
    } while (node != nullptr);

    SplayTreeNode<T>* const leaf_parent = path.back();
    SplayTreeNode<T>*& slot = (key < leaf_parent->data) ? leaf_parent->left_child : leaf_parent->right_child;
    SplayTreeNode<T>* const fresh = new SplayTreeNode<T>(key);
    slot = fresh;
    adjustUntilRoot(fresh, path, path.size());
    return true;
}
// Splay after an unsuccessful bottom-up search/remove.
// work_stack is the access path from the root to the last node visited (work_stack.back());
// key would have been that node's missing child. Precondition: work_stack is non-empty.
//  - key < last node: the last node is key's in-order successor. Walk up while the current node
//    is a left child. The first ancestor reached through a right-child link is key's predecessor;
//    it is splayed if it exists, otherwise the last node is splayed.
//  - otherwise: mirror image. The successor ancestor is splayed if it exists, otherwise the last node.
// NOTE(review): textbook splaying splays the last node visited. Splaying a shallower ancestor leaves
// the deeper part of the path unrestructured, which may weaken the amortized bound; confirm intent.
template <typename T>
void SplayTree<T>::splay(vector<SplayTreeNode<T>*>& work_stack, const T& key)
{
SplayTreeNode<T>* cur = work_stack.back();
size_t d = work_stack.size(); // cur == work_stack[d - 1]; its ancestors are work_stack[0 .. d-2]
if (key < cur->data)
{
// Climb through left-child links; stop when cur is a right child (its parent is key's predecessor).
while (true)
{
--d;
if (d == 0)
break;
if (work_stack[d - 1]->right_child == cur)
break;
cur = work_stack[d - 1];
}
}
else
{
// Climb through right-child links; stop when cur is a left child (its parent is key's successor).
while (true)
{
--d;
if (d == 0)
break;
if (work_stack[d - 1]->left_child == cur)
break;
cur = work_stack[d - 1];
}
}
if (d == 0)
{
// No such neighbour ancestor: splay the last node visited.
adjustUntilRoot(work_stack.back(), work_stack, work_stack.size() - 1);
}
else
{
// Splay the neighbour ancestor work_stack[d - 1]; its own ancestors are work_stack[0 .. d-2].
cur = work_stack[d - 1];
--d;
adjustUntilRoot(cur, work_stack, d);
}
}
// Removes key. Top-down mode delegates to spreadTopDownAndOP.
// Bottom-up mode: returns false if key is absent; for a non-empty tree, splay() then restructures
// along the search path. When key is found, the node is unlinked (or its key replaced) and a
// neighbouring node is splayed to the root:
//  - two children, or root with a right child: the successor's key is moved into the node
//    (rightReplace) and that node is splayed.
//  - root with only a left child: the predecessor's key is moved into the root (leftReplace).
//  - non-root with only a left child: the in-order successor is splayed if one exists, otherwise
//    the predecessor (maximum of the removed node's former left subtree).
//  - non-root with only a right child: the in-order predecessor is splayed if one exists, otherwise
//    the successor (minimum of the removed node's former right subtree).
//  - non-root leaf: its parent is splayed.
//  - root leaf: the tree becomes empty.
// NOTE(review): the right-child-only and leaf cases prefer the predecessor/parent rather than the
// successor. Confirm this matches the intended "prefer successor" policy.
template <typename T>
bool SplayTree<T>::remove(const T& key)
{
if (use_spread_topdown)
{
return spreadTopDownAndOP(OperateType::DELETE, key);
}
SplayTreeNode<T>* cur = root;
vector<SplayTreeNode<T>*> work_stack; // ancestors of cur, from the root down to cur's parent
while (cur != nullptr)
{
if (cur->data == key)
{
break;
}
work_stack.push_back(cur);
if (key < cur->data)
{
cur = cur->left_child;
}
else
{
cur = cur->right_child;
}
}
if (cur == nullptr)
{
// Not found: still splay along the search path (unless the tree is empty).
if (root != nullptr)
{
splay(work_stack, key);
}
return false;
}
size_t d = work_stack.size();
if (cur->left_child != nullptr && cur->right_child != nullptr || work_stack.empty() && cur->right_child != nullptr) // cases a-d (labels presumably refer to figures in the source article): two children, or root with a right child
{
rightReplace(cur);
}
else if (work_stack.empty() == false && cur->left_child != nullptr) // cases e, f: non-root with only a left child
{
if (cur == work_stack.back()->left_child)
{
// cur is a left child: after splicing, its parent is the successor and gets splayed below.
work_stack.back()->left_child = cur->left_child;
}
else
{
work_stack.back()->right_child = cur->left_child;
SplayTreeNode<T>* p = cur->left_child;
while (d >= 1) // walk up to the nearest ancestor whose left subtree contains cur: cur's in-order successor
{
if (work_stack[d - 1]->left_child == p)
{
break;
}
p = work_stack[--d];
}
if (d == 0)
{
// No successor: splay the predecessor, i.e. the rightmost node of cur's former left subtree.
p = cur->left_child;
delete cur;
while (p->right_child != nullptr)
{
work_stack.push_back(p);
p = p->right_child;
}
adjustUntilRoot(p, work_stack, work_stack.size());
return true;
}
}
delete cur;
cur = work_stack[--d];
}
else if (work_stack.empty() && cur->left_child != nullptr) // case g: root with only a left child
{
leftReplace(cur);
}
else if (work_stack.empty() == false && cur->right_child != nullptr) // cases h, i: non-root with only a right child
{
if (cur == work_stack.back()->right_child)
{
// cur is a right child: after splicing, its parent is the predecessor and gets splayed below.
work_stack.back()->right_child = cur->right_child;
}
else
{
work_stack.back()->left_child = cur->right_child;
SplayTreeNode<T>* p = cur->right_child;
while (d >= 1) // walk up to the nearest ancestor whose right subtree contains cur: cur's in-order predecessor
{
if (work_stack[d - 1]->right_child == p)
{
break;
}
p = work_stack[--d];
}
if (d == 0)
{
// No predecessor: splay the successor, i.e. the leftmost node of cur's former right subtree.
p = cur->right_child;
delete cur;
while (p->left_child != nullptr)
{
work_stack.push_back(p);
p = p->left_child;
}
adjustUntilRoot(p, work_stack, work_stack.size());
return true;
}
}
delete cur;
cur = work_stack[--d];
}
else if (work_stack.empty() == false) // cases k, l: non-root leaf; its parent is splayed
{
if (cur == work_stack.back()->left_child)
{
work_stack.back()->left_child = nullptr;
}
else
{
work_stack.back()->right_child = nullptr;
}
delete cur;
cur = work_stack[--d];
}
else
{
// The root was the only node.
delete cur;
cur = root = nullptr;
}
// Splay the chosen node; its ancestors are work_stack[0 .. d-1].
adjustUntilRoot(cur, work_stack, d);
return true;
}
int main()
{
const int N = 2000;
SplayTree<int> test_obj(true);
vector<int> test_data(N);
for (int i = 0; i < N; ++i)
{
test_data[i] = i + 1;
}
/*for (int i = 0; i < N; ++i)
{
test_data.push_back(i + 1);
}*/
shuffle(test_data.begin(), test_data.end(), default_random_engine());
for (const int& i : test_data)
{
cout << "插入关键码" << i << endl;
cout << endl;
if (test_obj.insert(i))
{
cout << "插入成功" << endl;
if (test_obj.curTreeIsBST())
{
cout << "当前树是二叉搜索树" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
else
{
cout << "ERROR:当前树不是二叉搜索树" << endl;
exit(-1);
}
}
else
{
cout << "插入失败" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
cout << endl;
}
/*for (int i = 0; i < N; ++i)
{
++test_data[i];
}*/
/*for (const int& i : test_data)
{
cout << "搜索关键码" << i << endl;
cout << endl;
if (test_obj.search(i))
{
cout << "搜索成功" << endl;
if (test_obj.curTreeIsBST())
{
cout << "当前树是二叉搜索树" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
else
{
cout << "ERROR:当前树不是二叉搜索树" << endl;
exit(-1);
}
}
else
{
cout << "搜索失败" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
cout << endl;
}*/
size_t count = 0;
for (const int& i : test_data)
{
cout << "删除关键码" << i << endl;
cout << endl;
if (test_obj.remove(i))
{
cout << "删除成功" << endl;
if (test_obj.curTreeIsBST())
{
cout << "当前树是二叉搜索树" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
else
{
cout << "ERROR:当前树不是二叉搜索树" << endl;
exit(-1);
}
}
else
{
cout << "删除失败" << endl;
cout << "根节点关键码";
if (test_obj.isEmpty())
{
cout << "NULL" << endl;
}
else
{
cout << test_obj.getRootValue() << endl;
}
//cout << "中序序列为";
//test_obj.outputInorderSeq();
//cout << endl;
}
cout << endl;
cout << endl;
}
return 0;
}