还是抽时间把以前的代码优化了一下。为了省事，二叉搜索树类（BSTTree）继承了基本二叉树类（BasicTree），平衡二叉树类（AVLTree）又继承了二叉搜索树类。基本二叉树模块 binary_tree 的代码下面没有贴出，但在文章《二叉树可视化 终端打印》中已经给出。如果有 bug 或者建议，欢迎指出交流。
对了，为了调试平衡二叉树，我在可视化输出中把每个结点的 bf（平衡因子，即左子树深度减去右子树深度）也打印了出来（见 avl.py 中的 Node.__str__），可视情况自行去掉。
bst.py
from binary_tree import BasicTree, Node
class BSTTree(BasicTree):
    """Binary search tree built on top of ``BasicTree``.

    Duplicate values are not stored: adding a value that is already present
    returns the node that holds it and leaves the tree unchanged.
    """

    def add(self, value: int):
        """Insert ``value`` and return the node holding it.

        Returns the newly created node, or the already-existing node when
        ``value`` is a duplicate.
        """
        if not self.root:
            self.root = self.Node(value)
            return self.root
        current = self.root
        while True:
            if value < current.value:
                direction = 'left'
            elif value > current.value:
                direction = 'right'
            else:  # duplicate values are not allowed
                return current
            child = getattr(current, direction)
            if child is None:
                break
            current = child
        node = self.Node(value)
        node.parent = current
        node.direction = direction
        setattr(current, direction, node)
        return node

    def get_node(self, value: int):
        """Return the node holding ``value``, or ``None`` if it is absent."""
        if not self.root:
            return None
        current = self.root
        while current is not None:
            if value == current.value:
                return current
            current = current.left if value < current.value else current.right
        return None

    def delete_value(self, value):
        """Delete the node holding ``value``.

        Returns whatever ``delete`` returns, or ``None`` if ``value`` is not
        in the tree.
        """
        node = self.get_node(value)
        if node:
            node = self.delete(node)
        return node

    def delete(self, node: Node):
        """Remove ``node``'s value from the tree and return the detached leaf.

        A node with children is not unlinked directly. Its value is swapped
        with the in-order predecessor (max of the left subtree), or, when
        there is no left subtree, with the in-order successor (min of the
        right subtree). This repeats until the value sits in a leaf, which is
        then detached.

        The returned leaf carries the deleted value and keeps its former
        ``parent``/``direction`` links. AVLTree.check_bf reads those links to
        start rebalancing.
        """
        while True:
            if node.left is not None:
                target = self.get_max(node.left)
            elif node.right is not None:
                target = self.get_min(node.right)
            else:
                break
            node.value, target.value = target.value, node.value
            node = target
        if node.parent is None:
            # The leaf is the root itself, i.e. the tree held a single node.
            # Previously this called setattr(None, None, None) and crashed.
            self.root = None
        else:
            setattr(node.parent, node.direction, None)
        return node

    def get_min(self, node: Node):
        """Return the node with the smallest value in the subtree rooted at ``node``."""
        while node.left is not None:
            node = node.left
        return node

    def get_max(self, node: Node):
        """Return the node with the largest value in the subtree rooted at ``node``."""
        while node.right is not None:
            node = node.right
        return node
if __name__ == '__main__':
    tree = BSTTree()
    # Alternative inputs from the original demo:
    #   list(range(5))
    #   [5, 4, 6, 2, 8, 1, 3, 7, 9, 4.5]
    for item in [10, 5, 20, 0, 7, 15, 25, 12, 17, 22, 30]:
        tree.add(item)
    print("添加结果:")
    print(tree)

    # Test 1: delete the leaves 12 and 17.
    # Test 2: delete 20, which has children.
    # Test 3: insert 6, then delete the root value 10.
    steps = [
        ("delete", 12),
        ("delete", 17),
        ("delete", 20),
        ("add", 6),
        ("delete", 10),
    ]
    for operation, value in steps:
        if operation == "add":
            tree.add(value)
            print(f"添加{value}结果:")
        else:
            tree.delete_value(value)
            print(f"删除{value}结果: ")
        print(tree)
avl.py
from bst import BSTTree
class Node:
    """AVL tree node carrying its balance factor and a link to its parent."""

    def __init__(self, value: int):
        self.value = value
        # Balance factor: depth of the left subtree minus depth of the right.
        self.bf = 0
        # Children (author's note: lchild / rchild would be clearer names).
        self.left, self.right = None, None
        # Back-link to the parent, kept to make deletion easier.
        self.parent = None
        # Side of ``parent`` this node hangs on: "left", "right" or None.
        self.direction = None

    def __str__(self):
        # Shows the balance factor next to the value for debugging;
        # return the bare value instead to hide it.
        return "{}({})".format(self.value, self.bf)

    def __len__(self):
        # Width of the printed label (presumably used by the tree printer
        # for layout — confirm against binary_tree).
        return len(self.__str__())
class AVLTree(BSTTree):
    """Self-balancing (AVL) binary search tree.

    Every node stores ``bf`` = depth(left subtree) - depth(right subtree).
    After each insertion or deletion, balance factors along the path to the
    root are updated. Any node that reaches +/-2 is repaired by a rotation.
    """

    Node = Node

    def add(self, value: int):
        """Insert ``value``, rebalance, and return the node holding it.

        Duplicates are rejected. If ``value`` already exists, its node is
        returned and the tree is left untouched.
        """
        existing = self.get_node(value)
        if existing is not None:
            # BSTTree.add would hand back this existing node. Running
            # check_bf on it would wrongly change its parent's balance factor.
            return existing
        node = super().add(value)
        # Walks up from the new node; returns at once if it is the root.
        self.check_bf(node, "add")
        return node

    def delete(self, node: Node):
        """Delete ``node``'s value, rebalance, and return the detached leaf.

        Returns ``None`` when ``node`` is ``None``.
        """
        if node is None:
            return None
        delete_node = super().delete(node)
        # The detached leaf still references its former parent and side,
        # which is where rebalancing has to start.
        self.check_bf(delete_node, "delete")
        return delete_node

    def check_bf(self, node: Node, change_type: str):
        """Propagate a subtree height change from ``node`` towards the root.

        node: the newly added node, the detached leaf, or (on recursive calls)
            the root of a subtree whose height just changed.
        change_type: "add" (the subtree at ``node`` grew by one level),
            "delete" (it shrank by one level), or anything else (stop).

        Adjusts the parent's balance factor and rotates if it reaches +/-2.
        Continues upward while the height change keeps propagating.
        """
        parent = node.parent
        if parent is None:
            return
        if change_type == "add":
            if node.direction == "left":
                parent.bf += 1
            else:
                parent.bf -= 1
            # bf 0: the parent's height is unchanged. bf +/-2: the rotation
            # below restores the pre-insert height. Either way, stop here.
            if parent.bf in {0, 2, -2}:
                change_type = "no_change"
        elif change_type == "delete":
            if node.direction == "left":
                parent.bf -= 1
            else:
                parent.bf += 1
            # bf +/-1: the parent was balanced before, so its height is unchanged.
            if parent.bf in {-1, 1}:
                change_type = "no_change"
        else:
            return
        if parent.bf in {-2, 2}:
            parent = self.balance_spin(parent)
            # After a deletion-triggered rotation, the subtree keeps its
            # height only if the new subtree root is not perfectly balanced.
            if change_type == "delete" and parent.bf != 0:
                change_type = "no_change"
        if change_type != "no_change":
            self.check_bf(parent, change_type)

    def balance_spin(self, node: Node):
        """Rebalance the subtree rooted at ``node``, whose bf is +/-2.

        Chooses a single or double rotation, sets the balance factors of the
        nodes involved, and returns the new root of the subtree.
        """
        if node.bf == -2:  # right side is two levels deeper
            if node.right.bf == -1:  # RR: single left rotation
                spin_node = node.right
                node.bf = spin_node.bf = 0
                self.spin(node=spin_node, spin_type="left")
            elif node.right.bf == 1:  # RL: right rotation, then left rotation
                spin_node = node.right.left
                if spin_node.bf == 0:
                    node.bf = node.right.bf = 0
                elif spin_node.bf == -1:
                    node.bf = 1
                    spin_node.bf = node.right.bf = 0
                elif spin_node.bf == 1:
                    node.right.bf = -1
                    spin_node.bf = node.bf = 0
                self.spin(node=spin_node, spin_type="right")
                self.spin(node=spin_node, spin_type="left")
            elif node.right.bf == 0:  # can only happen after a deletion
                spin_node = node.right
                node.bf = -1
                spin_node.bf = 1
                self.spin(node=spin_node, spin_type="left")
        elif node.bf == 2:  # left side is two levels deeper
            if node.left.bf == 1:  # LL: single right rotation
                spin_node = node.left
                node.bf = node.left.bf = 0
                self.spin(node=spin_node, spin_type="right")
            elif node.left.bf == -1:  # LR: left rotation, then right rotation
                spin_node = node.left.right
                if spin_node.bf == 0:
                    node.bf = node.left.bf = 0
                elif spin_node.bf == -1:
                    node.left.bf = 1
                    spin_node.bf = node.bf = 0
                elif spin_node.bf == 1:
                    node.bf = -1
                    spin_node.bf = node.left.bf = 0
                self.spin(node=spin_node, spin_type="left")
                self.spin(node=spin_node, spin_type="right")
            elif node.left.bf == 0:  # can only happen after a deletion
                spin_node = node.left
                node.bf = 1
                spin_node.bf = -1
                self.spin(node=spin_node, spin_type="right")
        return spin_node

    def spin(self, node: Node, spin_type: str):
        """Rotate ``node`` up over its parent.

        spin_type "right": ``node`` is its parent's left child.
        spin_type "left": ``node`` is its parent's right child.
        Any other value is ignored. Balance factors are not touched here.
        """
        if spin_type == "right":
            inner, outer = "right", "left"
        elif spin_type == "left":
            inner, outer = "left", "right"
        else:
            return
        parent_node = node.parent
        grandparent = parent_node.parent
        # node's inner subtree moves over to the side of parent_node that
        # node is vacating.
        moved = getattr(node, inner)
        setattr(parent_node, outer, moved)
        if moved is not None:
            moved.parent = parent_node
            moved.direction = outer
        setattr(node, inner, parent_node)
        # node takes parent_node's place under the grandparent (or as root).
        if grandparent is None:
            self.root = node
            node.parent = None
            node.direction = None  # a root hangs on no side
        else:
            node.parent = grandparent
            node.direction = parent_node.direction
            setattr(grandparent, parent_node.direction, node)
        parent_node.parent = node
        parent_node.direction = inner
if __name__ == "__main__":
    tree = AVLTree()
    for item in [50, 40, 80, 30, 45, 90, 20, 42, 47]:
        tree.add(item)
    print("添加结果:")
    print(tree)

    # Each step is (operation, values applied in order, caption printed
    # afterwards). The steps cover:
    #   - the ll/rr/lr/rl rotation cases triggered by insertion,
    #   - the same cases triggered by deletion,
    #   - deletion of arbitrary nodes.
    # NOTE(review): 10 and 3 are never inserted, so the "rr-2" and "ll-2"
    # deletions are no-ops. Confirm the intended values.
    steps = [
        ("add", [1], "测试ll, 添加1: "),
        ("add", [100], "测试rr, 添加100: "),
        ("add", [70], "测试lr-1, 添加70: "),
        ("add", [75], "测试lr-2, 添加75: "),
        ("add", [12], "测试rl-1, 添加12: "),
        ("add", [11], "测试rl-2, 添加11: "),
        ("add", [13], "测试删除: rr-1, 添加13: "),
        ("delete", [10], "测试删除: rr-2, 删除10: "),
        ("add", [5], "测试删除: ll-1, 添加5: "),
        ("delete", [3], "测试删除: ll-2, 删除3: "),
        ("delete", [5, 42, 45], "测试删除: lr-1, 删除5, 42, 45: "),
        ("delete", [47], "测试删除: lr-2, 删除47: "),
        ("delete", [13, 70, 80], "测试删除: rl-1, 删除13, 70, 80: "),
        ("delete", [75], "测试删除: rl-2, 删除75: "),
        ("delete", [50], "测试删除任意结点50: "),
        ("delete", [40], "测试删除任意结点40: "),
        ("delete", [1], "测试删除任意结点1: "),
    ]
    for operation, values, caption in steps:
        action = tree.add if operation == "add" else tree.delete_value
        for value in values:
            action(value)
        print(caption)
        print(tree)