python实现二叉搜索树和平衡二叉树

还是抽时间把以前的代码优化了一下。为了省事,二叉搜索树继承了基本二叉树的类,平衡二叉树的类继承了二叉搜索树的类,然后基本二叉树binary_tree的代码下面没有,但在文章《二叉树可视化 终端打印》中已经给出。如果有bug或者建议,欢迎指出交流。

对了,为了调试平衡二叉树,我在可视化中将每个结点的bf打印了出来,视情况可自行去掉。

bst.py

from binary_tree import BasicTree, Node


class BSTTree(BasicTree):
    """二叉搜索树"""

    def add(self, value: int):
        """新增结点"""
        if not self.root:
            self.root = self.Node(value)
            return self.root

        node = self.Node(value)
        next_node = self.root
        while True:
            if value < next_node.value:
                child_node = next_node.left
                direction = 'left'
            elif value > next_node.value:
                child_node = next_node.right
                direction = 'right'
            else:  # 不允许重复结点
                return next_node
            if child_node:
                next_node = child_node
            else:
                node.parent = next_node
                node.direction = direction
                setattr(next_node, direction, node)
                break
        return node

    def get_node(self, value: int):
        """根据值查找结点"""
        if not self.root:
            return

        next_node = self.root
        while next_node:
            if next_node.value == value:
                return next_node
            elif value < next_node.value:
                next_node = next_node.left
            else:
                next_node = next_node.right

    def delete_value(self, value):
        """删除输入值对应的结点"""
        node = self.get_node(value)
        if node:
            node = self.delete(node)
        return node

    def delete(self, node: Node):
        """删除结点"""
        # assert node is not None, 'delete node should not be None'

        while True:
            if node.left is not None:
                delete_node = self.get_max(node.left)
                node.value, delete_node.value = delete_node.value, node.value
            elif node.right is not None:
                delete_node = self.get_min(node.right)
                node.value, delete_node.value = delete_node.value, node.value
            else:
                setattr(node.parent, node.direction, None)
                return node
            node = delete_node

    def get_min(self, node: Node):
        """获取小于结点的最小值"""
        while node.left:
            node = node.left
        return node

    def get_max(self, node: Node):
        """获取大于结点的最大值"""
        while node.right:
            node = node.right
        return node


if __name__ == '__main__':
    data = [10, 5, 20, 0, 7, 15, 25, 12, 17, 22, 30]
    # data = list(range(5))
    # data = [5, 4, 6, 2, 8, 1, 3, 7, 9, 4.5]
    tree = BSTTree()
    for i in data:
        tree.add(i)
    print("添加结果:")
    print(tree)
    # 测试一
    value = 12
    tree.delete_value(value)
    print(f"删除{value}结果: ")
    print(tree)
    value = 17
    tree.delete_value(value)
    print(f"删除{value}结果: ")
    print(tree)
    # 测试二
    value = 20
    tree.delete_value(value)
    print(f"删除{value}结果: ")
    print(tree)
    # 测试三
    value = 6
    tree.add(6)
    print(f"添加{value}结果:")
    print(tree)
    value = 10
    tree.delete_value(value)
    print(f"删除{value}结果: ")
    print(tree)

avl.py

from bst import BSTTree


class Node:
    def __init__(self, value: int):
        self.value = value
        self.bf = 0  # 左子树的深度减去右子树的深度
        self.left = None  # 其实应该改成 lchild 和 rchild
        self.right = None
        self.parent = None  # 方便执行删除操作
        self.direction = None

    def __str__(self):
        return f"{self.value}({self.bf})"  # 打印bf值
        # return f'{self.value}'

    def __len__(self):
        return len(str(self))


class AVLTree(BSTTree):
    """平衡二叉树"""

    Node = Node

    def add(self, value: int):
        """添加结点"""
        if not self.root:
            self.root = self.Node(value)
            return self.root

        node = super().add(value)
        self.check_bf(node, "add")
        return node

    def delete(self, node: Node):
        """删除结点"""
        if not node:
            return
        delete_node = super().delete(node)
        self.check_bf(delete_node, "delete")

    def check_bf(self, node: Node, change_type: str):
        """
        node: 新增结点或者被删结点
        change_type: add / delete / no_change
        从新添加或者被删除的结点开始往上更正bf值, 以及检查每个父结点是否平衡
        """
        parent = node.parent
        if not parent:
            return

        if change_type == "add":
            if node.direction == "left":
                parent.bf += 1
            else:
                parent.bf -= 1
            if parent.bf in {0, 2, -2}:
                change_type = "no_change"
        elif change_type == "delete":
            if node.direction == "left":
                parent.bf -= 1
            else:
                parent.bf += 1
            if parent.bf in {-1, 1}:
                change_type = "no_change"
        else:
            return
        if parent.bf in {-2, 2}:
            parent = self.balance_spin(parent)
            if change_type == "add":
                change_type = "no_change"
            elif change_type == "delete":
                if parent.bf != 0:
                    change_type = "no_change"
        self.check_bf(parent, change_type)

    def balance_spin(self, node: Node):
        """
        node: 不平衡的结点
        平衡算法,判断旋转方式,返回平衡后的父结点
        """
        # 右边子树高
        if node.bf == -2:
            if node.right.bf == -1:
                spin_node = node.right
                node.bf = spin_node.bf = 0
                self.spin(node=spin_node, spin_type="left")
            elif node.right.bf == 1:
                spin_node = node.right.left
                if spin_node.bf == 0:
                    node.bf = node.right.bf = 0
                elif spin_node.bf == -1:
                    node.bf = 1
                    spin_node.bf = node.right.bf = 0
                elif spin_node.bf == 1:
                    node.right.bf = -1
                    spin_node.bf = node.bf = 0
                self.spin(node=spin_node, spin_type="right")
                self.spin(node=spin_node, spin_type="left")
            elif node.right.bf == 0:  # 删除节点时可能会出现这种情况
                spin_node = node.right
                node.bf = -1
                spin_node.bf = 1
                self.spin(node=spin_node, spin_type="left")
        # 左边子树高
        elif node.bf == 2:
            if node.left.bf == 1:
                spin_node = node.left
                node.bf = node.left.bf = 0
                self.spin(node=spin_node, spin_type="right")
            elif node.left.bf == -1:
                spin_node = node.left.right
                if spin_node.bf == 0:
                    node.bf = node.left.bf = 0
                elif spin_node.bf == -1:
                    node.left.bf = 1
                    spin_node.bf = node.bf = 0
                elif spin_node.bf == 1:
                    node.bf = -1
                    spin_node.bf = node.left.bf = 0
                self.spin(node=spin_node, spin_type="left")
                self.spin(node=spin_node, spin_type="right")
            elif node.left.bf == 0:
                spin_node = node.left
                node.bf = 1
                spin_node.bf = -1
                self.spin(node=spin_node, spin_type="right")
        return spin_node

    def spin(self, node: Node, spin_type: str):
        """旋转算法"""
        parent_node = node.parent

        if spin_type == "right":
            parent_node.left = node.right
            if node.right:
                node.right.parent = parent_node
                node.right.direction = "left"
            node.right = parent_node

            if parent_node.parent is None:
                self.root = node
                node.parent = None
            else:
                node.parent = parent_node.parent
                node.direction = parent_node.direction
                setattr(parent_node.parent, parent_node.direction, node)
            parent_node.parent = node
            parent_node.direction = "right"
        elif spin_type == "left":
            parent_node.right = node.left
            if node.left:
                node.left.parent = parent_node
                node.left.direction = "right"
            node.left = parent_node

            if parent_node.parent is None:
                self.root = node
                node.parent = None
            else:
                node.parent = parent_node.parent
                node.direction = parent_node.direction
                setattr(parent_node.parent, parent_node.direction, node)
            parent_node.parent = node
            parent_node.direction = "left"


if __name__ == "__main__":
    data = [50, 40, 80, 30, 45, 90, 20, 42, 47]
    tree = AVLTree()
    for i in data:
        tree.add(i)
    print("添加结果:")
    print(tree)

    # 测试 ll
    tree.add(1)
    print("测试ll, 添加1: ")
    print(tree)
    # 测试 rr
    tree.add(100)
    print("测试rr, 添加100: ")
    print(tree)
    # 测试lr
    tree.add(70)
    print("测试lr-1, 添加70: ")
    print(tree)
    tree.add(75)
    print("测试lr-2, 添加75: ")
    print(tree)
    # 测试rl
    node = tree.add(12)
    print("测试rl-1, 添加12: ")
    print(tree)
    node = tree.add(11)
    print("测试rl-2, 添加11: ")
    print(tree)

    # 测试删除

    node = tree.add(13)
    print("测试删除: rr-1, 添加13: ")
    print(tree)
    tree.delete_value(10)
    print("测试删除: rr-2, 删除10: ")
    print(tree)

    node = tree.add(5)
    print("测试删除: ll-1, 添加5: ")
    print(tree)
    tree.delete_value(3)
    print("测试删除: ll-2, 删除3: ")
    print(tree)

    tree.delete_value(5)
    tree.delete_value(42)
    tree.delete_value(45)
    print("测试删除: lr-1, 删除5, 42, 45: ")
    print(tree)

    tree.delete_value(47)
    print("测试删除: lr-2, 删除47: ")
    print(tree)

    tree.delete_value(13)
    tree.delete_value(70)
    tree.delete_value(80)
    print("测试删除: rl-1, 删除13, 70, 80: ")
    print(tree)
    tree.delete_value(75)
    print("测试删除: rl-2, 删除75: ")
    print(tree)

    # 测试任意删除
    tree.delete_value(50)
    print("测试删除任意结点50: ")
    print(tree)
    tree.delete_value(40)
    print("测试删除任意结点40: ")
    print(tree)
    tree.delete_value(1)
    print("测试删除任意结点1: ")
    print(tree)
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值