#include <iostream>
/// A single node of a binary search tree: one int key plus two child links.
/// A freshly constructed node is a leaf (both children are nullptr).
class TreeNode {
public:
    int val;                   // key stored at this node
    TreeNode* left = nullptr;  // left child (BST places smaller keys here)
    TreeNode* right = nullptr; // right child (BST places larger keys here)

    TreeNode(int value) : val(value) {}
};
class BST {
public:
BST() : root(nullptr) {}
void insert(int val) {
root = insertRecursive(root, val);
}
bool search(int val) {
return searchRecursive(root, val);
}
void remove(int val) {
root = removeRecursive(root, val);
}
private:
TreeNode* root;
TreeNode* insertRecursive(TreeNode* node, int val) {
if (node == nullptr) {
return new TreeNode(val);
}
if (val < node->val) {
node->left = insertRecursive(node->left, val);
}
else if (val > root->val) {
node->right = insertRecursive(node->right, val);
}
return node; // 如果已经存在的话,也是返回node
}
bool searchRecursive(TreeNode* node, int val) {
if (node == nullptr) return false;
if (node->val == val) return true;
if (val < node->val) {
return searchRecursive(node->left, val);
}
else {
return searchRecursive(node->right, val);
}
}
TreeNode* removeRecursive(TreeNode* node, int val) {
if (node == nullptr) {
return node;
}
if (val < node->val) {
node->left = removeRecursive(node->left, val);
}
else if (val > node->val) {
node->right = removeRecursive(node->right, val);
}
else {
// 左子树为空,把右子树接上
if (node->left == nullptr) {
TreeNode* temp = node->right;
delete node;
return temp;
}
// 右子树为空,把左子树接上
else if (node->right == nullptr) {
TreeNode* temp = node->left;
delete node;
return temp;
}
// 左右子树都不空,就找到右子树中最小的值,把值赋给node,并将这个最小的值移除
TreeNode* minRightSbutree = findMin(node->right);
// 将
node->val = minRightSbutree->val;
node->right = removeRecursive(node->right, minRightSbutree->val);
}
return node;
}
TreeNode* findMin(TreeNode* node) {
while (node->left != nullptr) {
node = node->left;
}
return node;
}
};
int main() {
BST bst;
bst.insert(50);
bst.insert(30);
bst.insert(70);
bst.insert(20);
bst.insert(40);
bst.insert(60);
bst.insert(80);
std::cout << "Searching for 40: " << (bst.search(40) ? "Found" : "Not found") << std::endl;
std::cout << "Searching for 90: " << (bst.search(90) ? "Found" : "Not found") << std::endl;
bst.remove(30);
std::cout << "Searching for 30 after removal: " << (bst.search(30) ? "Found" : "Not found") << std::endl;
return 0;
}