头文件:
#pragma once
#ifndef BST_H
#define BST_H
#include <stdexcept>
#include <utility>
namespace CPP {

/// One node of the binary search tree: a key/value pair plus links to the
/// left subtree (smaller keys) and right subtree (larger keys).
/// A node does not own its children; BST allocates and frees all nodes.
template<typename K, typename V>
class TreeNode
{
public:
    K key;
    V val;
    TreeNode* left;
    TreeNode* right;

    /// @param k     key to store (moved into the node)
    /// @param v     value to store (moved into the node)
    /// @param left  initial left child, none by default
    /// @param right initial right child, none by default
    TreeNode(K k, V v, TreeNode* left = nullptr, TreeNode* right = nullptr) :
        key(std::move(k)), val(std::move(v)), left(left), right(right)
    {}
};

/// Unbalanced binary search tree mapping unique keys of type K to values of type V.
///
/// Requirements: `operator<` on K must define a strict weak ordering; two keys
/// are treated as equal when neither is less than the other. K and V must be
/// copyable for the copy operations and for operator[] (which returns V by value).
///
/// The tree owns every node it allocates and frees them in its destructor.
/// No rebalancing is performed, so each operation costs O(height): roughly
/// O(log n) for random insertion order, O(n) when keys arrive sorted.
template<typename K, typename V>
class BST
{
    using Node = TreeNode<K, V>;

    int count = 0;         // number of nodes currently stored
    Node* root = nullptr;  // owning pointer to the root; nullptr when empty

    // Frees every node of the subtree rooted at `node`. Uses right rotations
    // instead of recursion so list-shaped trees cannot overflow the stack.
    static void destroy(Node* node) noexcept
    {
        while (node != nullptr)
        {
            if (node->left != nullptr)
            {
                // Rotate right: the left child becomes the subtree root.
                Node* child = node->left;
                node->left = child->right;
                child->right = node;
                node = child;
            }
            else
            {
                Node* next = node->right;
                delete node;
                node = next;
            }
        }
    }

    // Returns a deep copy of the subtree rooted at `src` (nullptr for an empty
    // subtree). If an allocation or key/value copy throws, the partially built
    // copy is freed and the exception propagates. Recursion depth = subtree height.
    static Node* clone(const Node* src)
    {
        if (src == nullptr)
        {
            return nullptr;
        }
        Node* copy = new Node(src->key, src->val);
        try
        {
            copy->left = clone(src->left);
            copy->right = clone(src->right);
        }
        catch (...)
        {
            destroy(copy);
            throw;
        }
        return copy;
    }

public:
    /// Creates an empty tree.
    BST() = default;

    /// Deep-copies every node of `other`.
    BST(const BST& other) : count(other.count), root(clone(other.root)) {}

    /// Takes over `other`'s nodes, leaving `other` empty.
    BST(BST&& other) noexcept : count(other.count), root(other.root)
    {
        other.count = 0;
        other.root = nullptr;
    }

    /// Copy assignment with the strong exception guarantee (copy-and-swap).
    BST& operator=(const BST& other)
    {
        if (this != &other)
        {
            BST copy(other);
            swap(copy);
        }
        return *this;
    }

    /// Frees this tree's nodes and takes over `other`'s, leaving `other` empty.
    BST& operator=(BST&& other) noexcept
    {
        if (this != &other)
        {
            destroy(root);
            count = other.count;
            root = other.root;
            other.count = 0;
            other.root = nullptr;
        }
        return *this;
    }

    /// Frees every node owned by the tree.
    ~BST()
    {
        destroy(root);
    }

    /// Exchanges the contents of two trees in O(1).
    void swap(BST& other) noexcept
    {
        std::swap(count, other.count);
        std::swap(root, other.root);
    }

    /// @return the number of stored key/value pairs.
    int size() const
    {
        return count;
    }

    /// @return true when the tree holds no nodes.
    bool empty() const
    {
        return count == 0;
    }

    /// Inserts (key, val) if no equal key is present.
    /// If an equal key already exists, the tree is left unchanged: the existing
    /// value is NOT overwritten.
    void insert(K key, V val)
    {
        // Walk the link slots so the new node can be attached without tracking a parent.
        Node** link = &root;
        while (*link != nullptr)
        {
            Node* node = *link;
            if (key < node->key)
            {
                link = &node->left;
            }
            else if (node->key < key)
            {
                link = &node->right;
            }
            else
            {
                return;  // duplicate key: keep the existing entry
            }
        }
        *link = new Node(std::move(key), std::move(val));
        ++count;
    }

    /// @return the node with the smallest key, or nullptr if the tree is empty.
    Node* findMin()
    {
        Node* node = root;
        while (node != nullptr && node->left != nullptr)
        {
            node = node->left;
        }
        return node;
    }

    /// @return the node with the largest key, or nullptr if the tree is empty.
    Node* findMax()
    {
        Node* node = root;
        while (node != nullptr && node->right != nullptr)
        {
            node = node->right;
        }
        return node;
    }

    /// Removes the node with the smallest key; does nothing on an empty tree.
    void popMin()
    {
        if (root == nullptr)
        {
            return;
        }
        Node** link = &root;
        while ((*link)->left != nullptr)
        {
            link = &(*link)->left;
        }
        Node* victim = *link;
        *link = victim->right;  // the min node has no left child; splice in its right subtree
        delete victim;
        --count;
    }

    /// Removes the node with the largest key; does nothing on an empty tree.
    void popMax()
    {
        if (root == nullptr)
        {
            return;
        }
        Node** link = &root;
        while ((*link)->right != nullptr)
        {
            link = &(*link)->right;
        }
        Node* victim = *link;
        *link = victim->left;  // the max node has no right child; splice in its left subtree
        delete victim;
        --count;
    }

    /// @return the node whose key equals `key`, or nullptr if there is none.
    Node* find(const K& key)
    {
        Node* node = root;
        while (node != nullptr)
        {
            if (key < node->key)
            {
                node = node->left;
            }
            else if (node->key < key)
            {
                node = node->right;
            }
            else
            {
                return node;
            }
        }
        return nullptr;
    }

    /// @return a copy of the value stored under `key`.
    /// @throws std::out_of_range if `key` is not present.
    V operator[](const K& key)
    {
        Node* node = find(key);
        if (node == nullptr)
        {
            throw std::out_of_range("BST::operator[]: key not found");
        }
        return node->val;
    }
};

}
#endif
main:
#include <iostream>
#include <string>
#include "BST.h"
using namespace std;
using namespace CPP;
void test()
{
BST<int,string> bst;
bst.insert(5,"Alice");
bst.insert(10,"Bob");
bst.insert(12,"Coc");
bst.insert(12, "DDD");
cout << bst.findMin()->key << ":" << bst.findMin()->val << endl;
cout << bst.findMax()->key << ":" << bst.findMax()->val << endl;
auto n2 = bst.find(12);
cout << n2->key << ":" << n2->val << endl;
cout << "bst[10]:"<<bst[10] << endl;
cout << "____________________" << endl;
while (!bst.empty())
{
auto n3 = bst.findMin();
cout << n3->key << ":" << n3->val << endl;
bst.popMin();
}
bst.insert(5, "Alice");
bst.insert(10, "Bob");
bst.insert(12, "Coc");
cout << "$$$$$$$$$$$$$$$$$$" << endl;
while (!bst.empty())
{
auto n3 = bst.findMax();
cout << n3->key << ":" << n3->val << endl;
bst.popMax();
}
}
// Program entry point: runs the BST demo. Flowing off the end of main
// returns 0, exactly like an explicit `return 0;`.
int main()
{
    test();
}
结果: