import networkx as nx
import matplotlib.pyplot as plt
from collections import defaultdict
class Node:
    """Binary-tree node: a stored value plus optional left/right children."""

    def __init__(self, val):
        # The payload is kept under the attribute name ``name``; both child
        # links start out empty.
        self.name, self.left, self.right = val, None, None
def create(data, idx=0):
    """Build a binary tree from a level-order (heap-indexed) sequence.

    The node for ``data[idx]`` gets its left child from index ``2*idx + 1``
    and its right child from index ``2*idx + 2``, recursively.

    Args:
        data: Indexable sequence of node values in level order.
        idx: Index of the element that becomes the returned subtree's root.

    Returns:
        The ``Node`` built for ``data[idx]``, or ``None`` when ``idx`` is
        at or past ``len(data)``.
    """
    if idx < len(data):
        node = Node(data[idx])
        # Left subtree is built before the right one, as in level order.
        node.left, node.right = create(data, 2 * idx + 1), create(data, 2 * idx + 2)
        return node
    return None
def draw(node, figsize=(8, 10), node_size=1000):
    """Plot the binary tree rooted at *node* using networkx and matplotlib.

    Every tree node becomes its own graph vertex, keyed by a unique integer
    id, so nodes that share the same ``name`` (e.g. two ``'I'`` nodes) stay
    distinct. The label shown for each vertex is ``str(node.name)``.

    Layout: the root is placed at (0, 0). Each level is one unit lower than
    its parent. A child created at depth ``layer`` (the root's children use
    ``layer == 1``) is shifted horizontally by ``2 / 3 ** layer``: left
    children to the left, right children to the right.

    Args:
        node: Root of the (sub)tree to draw. It must be an object with
            ``name``, ``left`` and ``right`` attributes, or ``None``.
        figsize: Matplotlib figure size; enlarge it for deeper trees.
        node_size: Vertex size forwarded to ``nx.draw_networkx``.

    Returns:
        None. The figure is displayed with ``plt.show()``. An empty tree
        (``node is None``) draws nothing and returns immediately.
    """
    if node is None:
        # Nothing to plot; the old code crashed unpacking a None result here.
        return

    graph = nx.DiGraph()
    pos = {}
    labels = {}
    # Explicit stack instead of recursion: (tree node, parent id, x, y, layer).
    stack = [(node, None, 0, 0, 1)]
    while stack:
        cur, parent, x, y, layer = stack.pop()
        vid = len(pos)  # one fresh id per visited tree node
        graph.add_node(vid)
        if parent is not None:
            # Only real parent/child pairs get edges; no dummy root vertex.
            graph.add_edge(parent, vid)
        pos[vid] = (x, y)
        labels[vid] = str(cur.name)
        dx = 2 / 3 ** layer
        for child, cx in ((cur.right, x + dx), (cur.left, x - dx)):
            if child is not None:
                stack.append((child, vid, cx, y - 1, layer + 1))

    _, ax = plt.subplots(figsize=figsize)
    nx.draw_networkx(graph, pos, ax=ax, labels=labels, node_size=node_size)
    plt.show()
if __name__ == "__main__":
    # Level-order sample: index i has children at 2*i + 1 and 2*i + 2.
    # 'I' appears twice, so the plot must cope with repeated values.
    words = ['hello', 'world', 'I', 'exist', 'because', 'I', 'think']
    draw(create(words))