启动脚本
import threading
from p5 import *
from com.dkd.data.mysql.b_tree import b_tree
def setup():
    """Create the canvas, the shared B-tree and the label font.

    Presumably invoked once by p5's ``run()`` before drawing starts.
    ``bt`` and ``font`` are module globals read by ``draw()`` and
    ``input_num()``.
    """
    global bt, font
    size(1000, 800)
    background(255)
    # Order-3 tree: a node is split as soon as it holds 3 keys.
    bt = b_tree(m=3)
    # Raw string: "\W", "\F" and "\c" are invalid escape sequences that only
    # worked by accident (SyntaxWarning since Python 3.12); value unchanged.
    # NOTE(review): Windows-only font path - confirm before running elsewhere.
    font = create_font(name=r"C:\Windows\Fonts\consola.ttf", size=16)
# Last error reported by draw(), used to avoid printing it on every frame.
_last_draw_error = None


def draw():
    """Clear the canvas and render the current tree.

    Presumably invoked repeatedly (once per frame) by p5's ``run()``.
    Rendering errors are reported once per distinct error instead of being
    silently swallowed, while still keeping the sketch running.
    """
    global _last_draw_error
    background(255)
    try:
        bt.draw_tree(width=width, font=font)
    except Exception as err:  # best effort: never abort the render loop
        message = repr(err)
        if message != _last_draw_error:
            print("draw_tree failed:", message)
            _last_draw_error = message
    else:
        _last_draw_error = None
def input_num():
    """Read integers from stdin and insert each one into ``bt``.

    Input that is not an integer is reported and skipped.  A failing insert
    is reported but does not stop the loop.  The loop ends quietly when stdin
    is closed (EOF) instead of killing the thread with a traceback.
    """
    while True:
        try:
            num = input("请输入数字:")
        except EOFError:
            break
        try:
            key = int(num.strip())
        except ValueError:
            print("not an integer:", num)
            continue
        try:
            bt.add_node(k=key)
        except Exception as err:  # keep accepting input after a failed insert
            print("add_node failed:", repr(err))
if __name__ == '__main__':
    # Keyboard input runs on its own thread while run() drives the sketch.
    # daemon=True: Python exits only when no non-daemon threads remain, so a
    # non-daemon thread blocked in input() would keep the process alive after
    # the sketch finishes.
    t = threading.Thread(target=input_num, daemon=True)
    t.start()
    run()
B 树工具类
'''
B 树 多路平衡查找树
'''
import bisect
import threading

from p5 import *
class b_tree:
    """B-tree (multi-way balanced search tree) of order ``m``, drawable with p5.

    Each node keeps its keys sorted in ``ks`` and its children in ``cs``.
    Child ``cs[i]`` holds the keys that sort between ``ks[i - 1]`` and
    ``ks[i]``.  A node that reaches ``m`` keys is split around its middle
    key, which moves up into the parent.  The tree only grows taller when
    the root splits, so all leaves stay on the same level.

    ``add_node`` and ``draw_tree`` share a lock.  This stops one thread from
    inserting keys while another thread is laying out or drawing the tree.
    """

    def __init__(self, m):
        """Create an empty tree of order ``m``.

        :param m: a node is split once it holds ``m`` keys; must be >= 3 so
            both halves of a split keep at least one key.
        :raises ValueError: if ``m`` is smaller than 3.
        """
        if m < 3:
            raise ValueError("B-tree order m must be >= 3, got %r" % (m,))
        self.root = None
        self.m = m
        # Serialises structural changes (add_node) against layout/drawing.
        self._lock = threading.Lock()

    def add_node(self, k):
        """Insert key ``k`` into its leaf, splitting full nodes upwards."""
        with self._lock:
            if self.root is None:
                self.root = b_node()
                self.root.add_key(k=k)
                return
            # Descend to the leaf whose range contains k.  The child to follow
            # is the number of keys <= k in the current node; the old loop
            # kept the *last* key > k and added 1, picking the wrong child.
            cur_node = self.root
            while cur_node.has_children():
                cur_node = cur_node.cs[bisect.bisect_right(cur_node.ks, k)]
            cur_node.add_key(k=k)
            self.check_keys(node=cur_node)

    def check_keys(self, node):
        """Split ``node`` if it is full, then re-check its parent.

        The key at index ``m // 2`` moves up into the parent, or into a new
        root.  The keys and children on each side of it become two new nodes
        that take ``node``'s place among the parent's children.
        """
        if not node.is_full(self.m):
            return
        index = self.m // 2
        top_k = node.ks[index]
        left_node = b_node()
        left_node.add_keys(ks=node.ks[:index])
        right_node = b_node()
        right_node.add_keys(ks=node.ks[index + 1:])
        if node.has_children():
            # Split children by position: the first index + 1 subtrees lie
            # left of top_k.  Comparing each child's max_key() with top_k can
            # misplace subtrees when keys repeat.
            left_node.add_children(children=node.cs[:index + 1])
            right_node.add_children(children=node.cs[index + 1:])
        parent = node.p
        if parent is None:
            # Splitting the root is the only way the tree grows taller.
            self.root = b_node()
            self.root.add_key(k=top_k)
            self.root.add_children(children=[left_node, right_node])
            return
        # Put the two halves exactly where node was, and top_k in the
        # matching key slot, so parent.cs stays in key order.  Appending them
        # at the end broke both the descent in add_node and the layout.
        pos = next(i for i, c in enumerate(parent.cs) if c is node)
        parent.cs[pos:pos + 1] = [left_node, right_node]
        left_node.p = parent
        right_node.p = parent
        parent.ks.insert(pos, top_k)
        self.check_keys(node=parent)

    def draw_tree(self, width, font):
        """Lay out every node, then draw the whole tree with p5.

        Leaves are spaced evenly along the bottom row, centred on
        ``width / 2``.  Each parent is then placed one row above its children.
        """
        with self._lock:
            if self.root is None:
                return
            basic_x = int(width / 2)
            basic_y = 50
            radius = 50  # side length of one key box
            x_unit = radius * self.m  # horizontal slot reserved per leaf
            y_unit = 100  # vertical distance between levels
            tree_height = self.tree_height(node=self.root)
            all_leaf = self.find_all_leaf(node=self.root)
            leaf_half_width = int(len(all_leaf) * x_unit / 2)
            leaf_y = basic_y + (tree_height - 1) * y_unit
            leaf_parent = []
            for i, leaf in enumerate(all_leaf):
                leaf.x = basic_x + x_unit * i - leaf_half_width
                leaf.y = leaf_y
                if leaf.p is not None and leaf.p not in leaf_parent:
                    leaf_parent.append(leaf.p)
            self.fill_xy(nodes=leaf_parent, y_unit=y_unit)
            self.draw_node(node=self.root, font=font, radius=radius)

    def tree_height(self, node):
        """Return the number of levels below and including ``node`` (0 for None)."""
        if node is None:
            return 0
        return 1 + max((self.tree_height(node=c) for c in node.cs), default=0)

    def find_all_leaf(self, node):
        """Return the leaves of the subtree rooted at ``node``, left to right."""
        if node is None:
            return []
        if not node.has_children():
            return [node]
        leaves = []
        for c in node.cs:
            leaves.extend(self.find_all_leaf(node=c))
        return leaves

    def fill_xy(self, nodes, y_unit):
        """Position ``nodes`` above their children, then repeat one level up.

        All ``nodes`` must be on one level, and their children must already
        have coordinates.  A node is placed ``y_unit`` above its first child,
        horizontally midway between its first and last child.
        """
        while nodes:
            parents = []
            for node in nodes:
                node.x = node.cs[0].x
                node.y = node.cs[0].y - y_unit
                if len(node.cs) > 1:
                    node.x = int((node.x + node.cs[-1].x) / 2)
                if node.p is not None and node.p not in parents:
                    parents.append(node.p)
            nodes = parents

    def draw_node(self, node, font, radius):
        """Draw ``node``'s key boxes and labels, then its edges and subtrees.

        Key ``i`` is drawn as a ``radius``-sized square centred at
        ``(node.x + i * radius, node.y)``.  The edge to child ``i`` starts at
        the left edge of key box ``i`` (the gap between keys ``i - 1`` and
        ``i``).  It ends at the top of the child's first key box.
        """
        # NOTE(review): p5's documented API is snake_case (stroke_weight,
        # no_stroke, text_font, no_fill); confirm these camelCase names
        # resolve in the installed p5 version.
        half_radius = int(radius / 2)
        stroke(0, 0, 255)
        strokeWeight(1)
        fill(0, 255, 150)
        rect_mode(CENTER)
        for i in range(len(node.ks)):
            rect(node.x + i * radius, node.y, radius, radius)
        noStroke()
        fill(0)
        textFont(font=font, size=20)
        text_align(align_x=CENTER, align_y=CENTER)
        for i, key in enumerate(node.ks):
            text(str(key), node.x + i * radius, node.y)
        for i, child in enumerate(node.cs):
            # Styles are reset per edge because the recursive call changes them.
            noFill()
            stroke(255, 0, 0)
            strokeWeight(2)
            line(node.x - half_radius + i * radius, node.y + half_radius,
                 child.x, child.y - half_radius)
            self.draw_node(node=child, font=font, radius=radius)
class b_node:
    """A single B-tree node: sorted keys ``ks``, children ``cs``, parent ``p``."""

    def __init__(self):
        # ks: this node's keys, kept in ascending order.
        self.ks = []
        # cs: child nodes; empty for a leaf.
        self.cs = []
        # deep: not read by b_tree in this file; kept for compatibility.
        self.deep = 1
        # p: parent node, or None for the root.
        self.p = None
        # x, y: screen position assigned during b_tree layout.  Defaulted so
        # a node that has not been laid out yet still has coordinates.
        self.x = 0
        self.y = 0

    def add_key(self, k):
        """Insert ``k`` keeping ``ks`` sorted (placed after any equal keys)."""
        bisect.insort_right(self.ks, k)

    def add_keys(self, ks):
        """Insert every key from ``ks`` keeping ``self.ks`` sorted."""
        self.ks.extend(ks)
        self.ks.sort()

    def has_children(self):
        """Return True if this node has children, i.e. is not a leaf."""
        return bool(self.cs)

    def is_full(self, m):
        """Return True once the node holds ``m`` or more keys (needs a split)."""
        return len(self.ks) >= m

    def add_child(self, child):
        """Append ``child`` to ``cs`` and make this node its parent."""
        self.cs.append(child)
        child.p = self

    def add_children(self, children):
        """Append each node in ``children``, in order, via ``add_child``."""
        for child in children:
            self.add_child(child=child)

    def max_key(self):
        """Return the largest key; raises IndexError if the node has no keys."""
        return self.ks[-1]

    def remove_child(self, child):
        """Remove ``child`` from ``cs``; the child's ``p`` is left unchanged."""
        self.cs = [c for c in self.cs if c != child]

    def __repr__(self):
        return "b_node(ks=%r)" % (self.ks,)