学习《数据结构与算法分析》时为加深记忆而写，代码写得不太好，见谅。
/// One node of binary_search_tree: links to two child nodes plus a pointer
/// to a separately heap-allocated value. The node's own destructor frees
/// nothing; binary_search_tree::delete_node frees the value and the node.
template <typename T>
class bst_data
{
public:
    bst_data() = default;
    ~bst_data() = default;

    bst_data<T>* m_right{nullptr};  ///< subtree that insert() fills with larger values
    bst_data<T>* m_left{nullptr};   ///< subtree that insert() fills with smaller values
    T* m_value{nullptr};            ///< payload, allocated by binary_search_tree::create_node
};
template <class T>
class binary_search_tree
{
public:
binary_search_tree(){}
~binary_search_tree(){}
bst_data<T>* get_root() { return m_root_node; }
bst_data<T>* create_node()
{
bst_data<T> * node = new bst_data<T>;
node->m_value = new T;
++m_size;
return node;
}
void delete_node(bst_data<T> * node)
{
if (node)
{
delete node->m_value;
delete node;
}
--m_size;
}
bst_data<T>* create_tree(const T & value)
{
m_root_node = create_node();
*m_root_node->m_value = value;
return m_root_node;
}
bst_data<T>* clear(bst_data<T>* node)
{
if (node == nullptr)
return nullptr;
node->m_left = clear(node->m_left);
node->m_right= clear(node->m_right);
delete_node(node);
node = nullptr;
return nullptr;
}
size_t size()
{
return m_size;
}
public:
T* find_min(bst_data<T> * node =nullptr)
{
node = node == nullptr ? m_root_node : node;
return find_min_ex(node)->m_value;
}
T* find_max(bst_data<T> * node =nullptr)
{
node = node == nullptr ? m_root_node : node;
return find_max_ex(node)->m_value;
}
T* find(const T &value, bst_data<T>* node )
{
if (node == nullptr)
return nullptr;
if (value > *node->m_value)
return find(value,node->m_right);
return value<*node->m_value ? find(value, node->m_left) : node->m_value;
}
bst_data<T>* insert(const T &value, bst_data<T> * node )
{
if (node == nullptr)
{
auto element = create_node();
*element->m_value = value;
return element;
}
if (value > *node->m_value)
node->m_right = insert(value, node->m_right);
else if (value < *node->m_value)
node->m_left = insert(value, node->m_left);
else
return nullptr;
return node;
}
bst_data<T>* remove(const T&value, bst_data<T>* node )
{
if (value > *node->m_value)
node->m_right = remove(value, node->m_right);
else if (value < *node->m_value)
node->m_left = remove(value, node->m_left);
else if (node->m_left &&node->m_right)
{
bst_data<T>* temp = find_min_ex(node);
T temp_value = *temp->m_value;
*node->m_value = *temp->m_value;
node->m_right = remove(temp_value, node->m_right);
}
else
{
bst_data<T>* temp = node;
if (node->m_left == nullptr)
node = node->m_right;
else if (node->m_right == nullptr)
node = node->m_left;
delete_node(temp);
}
return node;
}
void for_each_element(bst_data<T>* node, std::function<void(T&)> func)
{
if (node == nullptr)
return;
func(*node->m_value);
for_each_element(node->m_left,func);
for_each_element(node->m_right,func);
}
private:
bst_data<T>* find_min_ex(bst_data<T> * node)
{
return node->m_left == nullptr ? node : find_min_ex(node->m_left);
}
bst_data<T>* find_max_ex(bst_data<T> * node)
{
return node->m_right == nullptr ? node : find_max_ex(node->m_right);
}
private:
bst_data<T>* m_root_node;
size_t m_size = 0;
};
// Demo: build a small tree rooted at 1, print its minimum, then print every
// value in pre-order.
int main()
{
    binary_search_tree<int> tree;
    tree.create_tree(1);

    // 1 is already the root, so inserting it again leaves the tree unchanged.
    const int values[] = {0, 1, 2, 3};
    for (const int value : values)
        tree.insert(value, tree.get_root());

    cout << *tree.find_min() << endl;

    auto print = [](int& element) {
        cout << element << endl;
    };
    tree.for_each_element(tree.get_root(), print);

    return 0;
}