二叉搜索树 (c++非递归版)

400多行代码(其实还有好多API没完成),请接受恐惧的代码吧得到节点count是递归的(为了不太复杂),下面也给出了测试代码

BST.h

#pragma once
#include <stdexcept>
#include <stack>
template<typename Key, typename Value>
class BST
{
private:
    class Node
    {
    public:
        Node* left = nullptr;
        Node* right = nullptr;
        Key key;
        Value value;
        int count = 0;
    public:
        Node(const Key& k, const Value& v, const int& c) :key(k), value(v), count(c)
        {
        }
    };
    Node* root = nullptr;
private:
/********************************************
函数名称: size
函数说明: 返回二叉树节点个数
返回值:   int
********************************************/
    int size(Node* r)
    {
        if (r == nullptr)
            return 0;
        return r->count;
    }
/******************************************
函数名称:  sync_size
函数说明:  更新各节点的count值
返回值:    void
*******************************************/
    void sync_size(Node* r)
    {
        if (r == nullptr)
            return;
        sync_size(r->left);
        sync_size(r->right);
        r->count = size(r->left) + size(r->right) + 1;
        return;
    }
/*******************************************
函数名称: put
函数说明: 存储节点于二叉搜索树
返回值:   Node*
*******************************************/
    Node* put(Node* r, const Key& k, const Value& v)
    {
        if (r == nullptr)
            return new Node(k, v, 1);
        Node* curr = r;
        while (true)
        {
            if (curr->key > k)
            {
                if (curr->left == nullptr)
                {
                    curr->left = new Node(k, v, 1);
                    break;
                }
                else
                    curr = curr->left;
            }
            else if (curr->key < k)
            {
                if (curr->right == nullptr)
                {
                    curr->right = new Node(k, v, 1);
                    break;
                }
                else
                    curr = curr->right;
            }
            else
            {
                curr->value = v;
                return r;
            }
        }
        sync_size(r);
        return r;
    }
/*******************************************
函数名称:  get
函数说明:  得到二叉树key对应value
返回值:    Value
******************************************/
    Value get(Node* r, const Key& key)
    {
        if (r == nullptr)
            throw std::out_of_range("can't get the value of key");
        while (true)
        {
            if (r->key > key)
            {
                if (r->left == nullptr)
                    throw  std::out_of_range("can't get the value of key");
                else
                    r = r->left;
            }
            else if (r->key < key)
            {
                if (r->right == nullptr)
                    throw  std::out_of_range("can't get the value of key");
                else
                    r = r->right;
            }
            else
                return r->value;
        }
    }
/******************************************
函数名称: min
函数说明: 取得二叉树key最小的节点
返回值:   Node*
********************************************/
    Node* min(Node* r)
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the min");
        while (true)
        {
            if (r->left == nullptr)
                return r;
            else
                r = r->left;
        }
    }
/******************************************
函数名称: max
函数说明: 取得二叉树里key最大的节点
返回值:   Node*
*******************************************/
    Node* max(Node* r)
    {
        if (r == nullptr)
            throw std::out_of_range("can't gain the max");
        while (true)
        {
            if (r->right == nullptr)
                return r;
            else
                r = r->right;
        }
    }
/******************************************
函数名称: deleteMin
函数说明: 删除二叉树中key最小的节点
返回值:   Node*
*******************************************/
    Node* deleteMin(Node *r)
    {
        if (r == nullptr)
            return nullptr;
        Node* curr = r;
        Node* pre = nullptr;
        while (true)
        {
            if (curr->left == nullptr)
            {
                Node* temp = curr;
                curr = curr->right;
                delete temp;
                if (pre != nullptr)
                    pre->left = curr;
                else
                    r = curr;
                sync_size(r);
                return r;
            }
            else
            {
                pre = curr;
                curr = curr->left;
            }
        }
    }
/******************************************
函数名称: deleteMax
函数说明: 删除二叉树key最大的节点
返回值:   void
*******************************************/
    Node* deleteMax(Node *r)
    {
        if (r == nullptr)
            return nullptr;
        Node* curr = r;
        Node* pre = nullptr;
        while (true)
        {
            if (curr->right == nullptr)
            {
                Node* temp = curr;
                curr = curr->left;
                delete temp;
                if (pre != nullptr)
                    pre->right = curr;
                else
                    r = curr;
                sync_size(r);
                return r;
            }
            else
            {
                pre = curr;
                curr = curr->right;
            }
        }
    }
/******************************************
函数名称: before_display
函数说明: 中序遍历二叉树打印节点value
返回值:   void
*******************************************/
    void before_display(Node* r)
    {
        if (r == nullptr)
            return;
        stack<Node*> s;
        while (r != nullptr || !s.empty())
        {
            if (r != nullptr)
            {
                s.push(r);
                r = r->left;
            }
            else
            {
                r = s.top();
                cout << r->value << ends;
                s.pop();
                r = r->right;
            }
        }
    }
/******************************************
函数名称:  erase
函数说明:  删除特定key的节点
*******************************************/
    Node* erase(Node* r, const Key& k)
    {
        if (r == nullptr)
            return nullptr;
        Node* curr = r;
        Node* pre = nullptr;
        while (curr != nullptr)
        {
            if (curr->key > k)
            {
                pre = curr;
                curr = curr->left;
            }
            else if (curr->key < k)
            {
                pre = curr;
                curr = curr->right;
            }
            else
            {
                if (curr->left == nullptr)
                {
                    if (pre != nullptr && pre->right == curr)
                    {
                        Node* temp = curr;
                        curr = curr->right;
                        delete temp;
                        pre->right = curr;
                        sync_size(r);
                        return r;
                    }
                    else if (pre != nullptr && pre->left == curr)
                    {
                        Node* temp = curr;
                        curr = curr->right;
                        delete temp;
                        pre->left = curr;
                        sync_size(r);
                        return r;
                    }
                    else
                    {
                        Node* temp = curr;
                        curr = curr->right;
                        delete temp;
                        r = curr;
                        sync_size(r);
                    }
                }
                else if (curr->right == nullptr)
                {
                    if (pre != nullptr && pre->right == curr)
                    {
                        Node* temp = curr;
                        curr = curr->left;
                        delete temp;
                        pre->right = curr;
                        sync_size(r);
                        return r;
                    }
                    else if (pre != nullptr && pre->left == curr)
                    {
                        Node* temp = curr;
                        curr = curr->left;
                        delete temp;
                        pre->left = curr;
                        sync_size(r);
                        return r;
                    }
                    else
                    {
                        Node* temp = curr;
                        curr = curr->left;
                        delete temp;
                        r = curr;
                        sync_size(r);
                        return r;
                    }
                }
                else
                {
                    if (pre != nullptr)
                    {
                        Node* get = min(curr->right);
                        get->right = deleteMin(curr->right);
                        get->left = curr->left;
                        if (pre->right == curr)
                            pre->right = get;
                        else
                            pre->left = get;
                        delete curr;
                        sync_size(r);
                        return r;
                    }
                    else
                    {
                        Node* get = min(curr->right);
                        get->right = deleteMin(curr->right);
                        get->left = curr->left;
                        r = get;
                        delete curr;
                        sync_size(r);
                        return r;
                    }
                }
            }
        }
    }
/*******************************************
函数名称:  floor
函数说明:  向下取整返回第一个key<= k的节点
返回值:    Node*
********************************************/
    Node* floor(Node* r, const Key& k)
    {
        if (r == nullptr)
            return nullptr;
        Node* pre = nullptr;
        while (true)
        {
            if (r->key == k)
                return r;
            else if (r->key > k)
            {
                if (r->left == nullptr)
                {
                    return pre;
                }
                else
                    r = r->left;
            }
            else
            {
                if (r->right == nullptr)
                    return r;
                else
                {
                    pre = r;
                    r = r->right;
                }
            }
        }
    }
public:
    int size()
    {
        return size(root);
    }
    void put(const Key& k, const int& v)
    {
        root = put(root, k, v);
    }
    Value get(const Key& k)
    {
        return get(root, k);
    }
    Key min()
    {
        return min(root)->key;
    }
    Key max()
    {
        return max(root)->key;
    }
    void deleteMin()
    {
        root = deleteMin(root);
    }
    void deleteMax()
    {
        root = deleteMax(root);
    }
    void before_display()
    {
        before_display(root);
    }
    void erase(const Key& k)
    {
        root = erase(root, k);
    }
    Key floor(const Key& k)
    {
        Node* ret;
        if ((ret = floor(root, k)) == nullptr)
            throw std::out_of_range("can't floor");
        else
            return ret->key;
    }
};

main.cpp

#include <iostream>
#include "BST.h"
using namespace std;

int main()
{
    BST<double, int> bst;
    for (int i = 0; i < 10; ++i)
        bst.put(i + 0.1, i);
    cout << "size: " << bst.size() << endl;
    try
    {
        cout << bst.get(1.1) << endl;
        bst.deleteMin();
        bst.deleteMax();
        bst.erase(3.1);
        cout << "min: " << bst.min() << endl;
        cout << "max: " << bst.max() << endl;
        cout << bst.size() << endl;
        cout << bst.floor(5.0) << endl;
    }
    catch (const exception& e)
    {
        cout << e.what();
    }
    bst.before_display();
    system("pause");
    return 0;
}

运行:

这里写图片描述

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值