400多行代码（其实还有好多API没完成），请接受这份恐怖的代码吧。计算节点个数 count 采用的是递归（为了不太复杂）。下面也给出了测试代码。
BST.h
#pragma once
#include <iostream>
#include <stack>
#include <stdexcept>
#include <utility>
/********************************************
Class: BST
Description: ordered symbol table backed by an (unbalanced) binary search
tree. Keys are compared with operator< only; two keys are considered equal
when neither is less than the other. Every node caches the size of its
subtree in `count`; the cache is adjusted only along the search path on each
insertion/removal, so every operation costs O(height). All traversals are
iterative, so a degenerate (list-shaped) tree cannot overflow the call stack.
The tree owns its nodes: copies are deep, moves transfer ownership, and the
destructor frees every node.
********************************************/
template<typename Key, typename Value>
class BST
{
private:
    // One tree node; `count` is the number of nodes in the subtree rooted here.
    class Node
    {
    public:
        Node* left = nullptr;
        Node* right = nullptr;
        Key key;
        Value value;
        int count = 0;
    public:
        Node(const Key& k, const Value& v, int c) : key(k), value(v), count(c)
        {
        }
    };
    Node* root = nullptr;

private:
    /********************************************
    Function: size
    Description: number of nodes in subtree r (0 for an empty subtree)
    ********************************************/
    static int size(const Node* r)
    {
        return r == nullptr ? 0 : r->count;
    }

    /********************************************
    Function: find
    Description: node holding key k in subtree r, or nullptr if absent
    ********************************************/
    static Node* find(Node* r, const Key& k)
    {
        while (r != nullptr)
        {
            if (k < r->key)
                r = r->left;
            else if (r->key < k)
                r = r->right;
            else
                return r;
        }
        return nullptr;
    }

    /********************************************
    Function: put
    Description: insert (k, v) into subtree r, or overwrite the value if k
    is already present. Returns the (possibly new) subtree root.
    ********************************************/
    static Node* put(Node* r, const Key& k, const Value& v)
    {
        if (Node* hit = find(r, k))
        {
            hit->value = v;
            return r;
        }
        // Allocate before touching any count so a throwing allocation or
        // Key/Value copy leaves the tree unchanged.
        Node* fresh = new Node(k, v, 1);
        Node** link = &r;
        while (*link != nullptr)
        {
            Node* curr = *link;
            ++curr->count;  // every ancestor of the new leaf gains one node
            // k is absent, so "not less" means strictly greater here.
            link = (k < curr->key) ? &curr->left : &curr->right;
        }
        *link = fresh;
        return r;
    }

    /********************************************
    Function: get
    Description: value stored under key in subtree r
    Throws: std::out_of_range if key is absent
    ********************************************/
    static Value get(Node* r, const Key& key)
    {
        Node* hit = find(r, key);
        if (hit == nullptr)
            throw std::out_of_range("can't get the value of key");
        return hit->value;
    }

    /********************************************
    Function: min
    Description: leftmost (smallest-key) node of subtree r
    Throws: std::out_of_range if r is empty
    ********************************************/
    static Node* min(Node* r)
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the min");
        while (r->left != nullptr)
            r = r->left;
        return r;
    }

    /********************************************
    Function: max
    Description: rightmost (largest-key) node of subtree r
    Throws: std::out_of_range if r is empty
    ********************************************/
    static Node* max(Node* r)
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the max");
        while (r->right != nullptr)
            r = r->right;
        return r;
    }

    /********************************************
    Function: detachMin
    Description: unlink the smallest node of the NON-EMPTY subtree *link
    WITHOUT freeing it, decrementing the counts of its ancestors inside that
    subtree. The caller owns the returned node.
    ********************************************/
    static Node* detachMin(Node** link)
    {
        while ((*link)->left != nullptr)
        {
            --(*link)->count;
            link = &(*link)->left;
        }
        Node* found = *link;
        *link = found->right;  // the min has no left child; its right subtree replaces it
        return found;
    }

    /********************************************
    Function: detachMax
    Description: mirror image of detachMin for the largest node
    ********************************************/
    static Node* detachMax(Node** link)
    {
        while ((*link)->right != nullptr)
        {
            --(*link)->count;
            link = &(*link)->right;
        }
        Node* found = *link;
        *link = found->left;
        return found;
    }

    /********************************************
    Function: deleteMin
    Description: remove and free the smallest node of subtree r (no-op when
    empty). Returns the new subtree root.
    ********************************************/
    static Node* deleteMin(Node* r)
    {
        if (r != nullptr)
            delete detachMin(&r);
        return r;
    }

    /********************************************
    Function: deleteMax
    Description: remove and free the largest node of subtree r (no-op when
    empty). Returns the new subtree root.
    ********************************************/
    static Node* deleteMax(Node* r)
    {
        if (r != nullptr)
            delete detachMax(&r);
        return r;
    }

    /********************************************
    Function: before_display
    Description: iterative in-order traversal writing each value followed by
    a space to os. (The original wrote std::ends, i.e. a NUL byte.)
    ********************************************/
    static void before_display(const Node* r, std::ostream& os)
    {
        std::stack<const Node*> pending;
        while (r != nullptr || !pending.empty())
        {
            while (r != nullptr)
            {
                pending.push(r);
                r = r->left;
            }
            r = pending.top();
            pending.pop();
            os << r->value << ' ';
            r = r->right;
        }
    }

    /********************************************
    Function: erase
    Description: remove the node with key k from subtree r (no-op when the
    key is absent). A node with two children is replaced by its in-order
    successor (Hibbard deletion). Returns the new subtree root.
    ********************************************/
    static Node* erase(Node* r, const Key& k)
    {
        Node* target = find(r, k);
        if (target == nullptr)
            return r;  // absent key: leave the tree and every count untouched

        // Walk down again, decrementing every proper ancestor of the target.
        Node** link = &r;
        while (*link != target)
        {
            Node* curr = *link;
            --curr->count;
            link = (k < curr->key) ? &curr->left : &curr->right;
        }

        if (target->left == nullptr)
            *link = target->right;
        else if (target->right == nullptr)
            *link = target->left;
        else
        {
            // Unlink (not free) the successor, then let it take target's place.
            Node* successor = detachMin(&target->right);
            successor->left = target->left;
            successor->right = target->right;
            successor->count = target->count - 1;
            *link = successor;
        }
        delete target;
        return r;
    }

    /********************************************
    Function: floor
    Description: node with the largest key <= k in subtree r, or nullptr
    ********************************************/
    static Node* floor(Node* r, const Key& k)
    {
        Node* best = nullptr;
        while (r != nullptr)
        {
            if (k < r->key)
                r = r->left;
            else if (r->key < k)
            {
                best = r;  // candidate; a larger one may exist to the right
                r = r->right;
            }
            else
                return r;
        }
        return best;
    }

    /********************************************
    Function: destroy
    Description: free every node of subtree r in O(n) time and O(1) extra
    space by rotating left children up until each node can be deleted.
    ********************************************/
    static void destroy(Node* r) noexcept
    {
        while (r != nullptr)
        {
            if (r->left != nullptr)
            {
                Node* l = r->left;
                r->left = l->right;
                l->right = r;
                r = l;
            }
            else
            {
                Node* next = r->right;
                delete r;
                r = next;
            }
        }
    }

    /********************************************
    Function: clone
    Description: deep copy of subtree src (same shape, keys, values, counts).
    On an exception the partial copy is freed and the exception rethrown.
    ********************************************/
    static Node* clone(const Node* src)
    {
        if (src == nullptr)
            return nullptr;
        Node* copy = new Node(src->key, src->value, src->count);
        try
        {
            std::stack<std::pair<const Node*, Node*>> pending;
            pending.push({src, copy});
            while (!pending.empty())
            {
                const Node* from = pending.top().first;
                Node* to = pending.top().second;
                pending.pop();
                if (from->left != nullptr)
                {
                    to->left = new Node(from->left->key, from->left->value, from->left->count);
                    pending.push({from->left, to->left});
                }
                if (from->right != nullptr)
                {
                    to->right = new Node(from->right->key, from->right->value, from->right->count);
                    pending.push({from->right, to->right});
                }
            }
        }
        catch (...)
        {
            destroy(copy);  // links are set as nodes are created, so the partial tree is well-formed
            throw;
        }
        return copy;
    }

public:
    BST() = default;
    BST(const BST& other) : root(clone(other.root))
    {
    }
    BST(BST&& other) noexcept : root(std::exchange(other.root, nullptr))
    {
    }
    // Unified copy/move assignment (copy-and-swap).
    BST& operator=(BST other) noexcept
    {
        std::swap(root, other.root);
        return *this;
    }
    ~BST()
    {
        destroy(root);
    }

    // Number of stored keys.
    int size() const
    {
        return size(root);
    }
    // Insert or overwrite the value for k.
    void put(const Key& k, const Value& v)
    {
        root = put(root, k, v);
    }
    // Throws std::out_of_range if k is absent.
    Value get(const Key& k) const
    {
        return get(root, k);
    }
    // Throws std::out_of_range if the tree is empty.
    Key min() const
    {
        return min(root)->key;
    }
    // Throws std::out_of_range if the tree is empty.
    Key max() const
    {
        return max(root)->key;
    }
    void deleteMin()
    {
        root = deleteMin(root);
    }
    void deleteMax()
    {
        root = deleteMax(root);
    }
    // Print all values in ascending key order, each followed by a space.
    void before_display(std::ostream& os = std::cout) const
    {
        before_display(root, os);
    }
    // Remove k if present; absent keys are ignored.
    void erase(const Key& k)
    {
        root = erase(root, k);
    }
    // Largest key <= k; throws std::out_of_range if none exists.
    Key floor(const Key& k) const
    {
        const Node* found = floor(root, k);
        if (found == nullptr)
            throw std::out_of_range("can't floor");
        return found->key;
    }
};
main.cpp
#include <iostream>
#include "BST.h"
using namespace std;
// Demo for BST: stores keys 0.1 .. 9.1 (values 0 .. 9), exercises the query
// and removal API, prints the remaining values in key order, then waits for
// Enter so a console window launched from an IDE stays open.
int main()
{
    BST<double, int> bst;
    for (int i = 0; i < 10; ++i)
        bst.put(i + 0.1, i);
    std::cout << "size: " << bst.size() << '\n';
    try
    {
        std::cout << bst.get(1.1) << '\n';
        bst.deleteMin();  // removes key 0.1
        bst.deleteMax();  // removes key 9.1
        bst.erase(3.1);
        std::cout << "min: " << bst.min() << '\n';
        std::cout << "max: " << bst.max() << '\n';
        std::cout << bst.size() << '\n';
        std::cout << bst.floor(5.0) << '\n';  // largest key <= 5.0
    }
    catch (const std::exception& e)
    {
        std::cout << e.what() << '\n';
    }
    bst.before_display();
    std::cout << '\n';
    // Portable replacement for system("pause"): that call needs <cstdlib> and
    // only works where a "pause" shell command exists (Windows cmd).
    std::cout << "Press Enter to continue..." << std::flush;
    std::cin.get();
    return 0;
}
运行: