import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.patches as patches
import torch
# Highlight colour (CSS "lightcoral") shared by hub nodes, causal nodes and
# causal edges.
causal_color = '#F08080'

# Node ids of the two example graphs. In each ``ori_nodes_*`` tensor the first
# id is the hub ("center") node and the rest are its neighbours; repeated ids
# are kept on purpose (they become parallel edges in a MultiGraph).
# ``causal_nodes_*`` lists the neighbours whose edges should be highlighted.
# The tensors are created on the CPU: they are only converted to Python lists,
# so the pasted ``device='cuda:0'`` merely made the script fail without a GPU.
ori_nodes_1 = torch.tensor(
    [996, 68, 1305, 1305, 1305, 1305, 1305, 8046, 2654, 8020, 2138, 8013,
     8011, 281, 1305, 1305, 1552, 6368, 6368, 6368, 6368, 6368, 935, 935,
     5209],
    dtype=torch.int32,
)
causal_nodes_1 = torch.tensor([1305, 5209, 935, 68], dtype=torch.int32)
ori_nodes_2 = torch.tensor(
    [41, 1296, 1296, 1296, 1296, 1296, 7988, 1305, 1296, 1296, 1296, 1296,
     1296, 1305, 2353, 68, 1296, 1296, 1296, 1305, 2353, 1305, 1305, 79,
     79],
    dtype=torch.int32,
)
causal_nodes_2 = torch.tensor([68, 1305, 1305, 1296], dtype=torch.int32)

# Convert to plain Python lists of ints for use as networkx node ids.
ori_nodes_1 = ori_nodes_1.tolist()
causal_nodes_1 = causal_nodes_1.tolist()
ori_nodes_2 = ori_nodes_2.tolist()
causal_nodes_2 = causal_nodes_2.tolist()
# Each ori_nodes_* list describes a star: the first id is the hub ("center")
# node and every remaining id is one of its neighbours.
center_node1, *other_nodes_1 = ori_nodes_1
center_node2, *other_nodes_2 = ori_nodes_2

# A MultiGraph keeps repeated neighbours as separate, parallel edges.
G = nx.MultiGraph()
# Nodes are inserted in the order: hub 1, its neighbours, hub 2, its neighbours.
G.add_nodes_from([center_node1, *other_nodes_1, center_node2, *other_nodes_2])

# Edge batches, added in this exact order: the light-gray hub->neighbour edges
# of both graphs first, then one extra highlight-coloured ("causal") edge per
# entry of causal_nodes_1 / causal_nodes_2.
edge_batches = (
    (center_node1, other_nodes_1, 'lightgray'),
    (center_node2, other_nodes_2, 'lightgray'),
    (center_node1, causal_nodes_1, causal_color),
    (center_node2, causal_nodes_2, causal_color),
)
for hub_node, spokes, spoke_color in edge_batches:
    G.add_edges_from(
        [(hub_node, spoke) for spoke in spokes],
        color=spoke_color,
        linestyle='-',
        width=3,
    )
# Per-node drawing colours, in G.nodes() order (the order nx.draw consumes
# them in). Both hub nodes and every causal node of either graph use the
# highlight colour; every other node is light gray.
highlighted_nodes = {center_node1, center_node2, *causal_nodes_1, *causal_nodes_2}
node_color = [
    causal_color if node in highlighted_nodes else 'lightgray'
    for node in G.nodes()
]

# Remove self-loops BEFORE extracting per-edge attributes: nx.draw pairs the
# edge_colors / edge_styles / edge_widths lists with G.edges() by position, so
# they must be built from the final edge set or they would be misaligned (and
# have the wrong length) whenever a self-loop exists.
self_loops = list(nx.selfloop_edges(G))
G.remove_edges_from(self_loops)

# Per-edge drawing attributes, in G.edges() order; each edge got 'color',
# 'linestyle' and 'width' when it was added.
edge_colors = [color for _, _, color in G.edges(data='color')]
edge_styles = [style for _, _, style in G.edges(data='linestyle')]
edge_widths = [width for _, _, width in G.edges(data='width')]
# Draw every edge except the highlighted (causal) ones.
def draw_curved_edges(G, pos, color='lightgray'):
    """Draw the non-causal edges of ``G`` as curved arcs beneath the highlights.

    Parallel edges of the MultiGraph are fanned out by giving each edge key
    its own curvature (``rad = 0.1 * (key + 1)``) so duplicates stay visible.
    Edges whose stored ``color`` attribute equals the module-level
    ``causal_color`` are skipped here; ``draw_highlighted_edge`` draws those.

    Args:
        G: networkx MultiGraph whose edges carry a ``color`` attribute.
        pos: mapping from node to an (x, y) position, e.g. the result of
            ``nx.spring_layout``.
        color: colour used for every arc drawn by this function.
    """
    ax = plt.gca()
    for u, v, key, edge_color in G.edges(keys=True, data='color'):
        # Recognise causal edges by the colour they were created with. Matching
        # on (center, node, key == 1) is unreliable: a causal edge's
        # auto-assigned key depends on how many gray parallel edges share that
        # node pair, and an undirected edge may be reported as (node, center).
        if edge_color == causal_color:
            continue
        rad = 0.1 * (key + 1)
        arc = patches.FancyArrowPatch(
            posA=pos[u],
            posB=pos[v],
            connectionstyle=f"arc3,rad={rad}",
            color=color,
            arrowstyle='->',
            mutation_scale=10,
            lw=3,
            zorder=0,  # low z-order: keep gray arcs underneath the highlights
        )
        ax.add_patch(arc)
# Draw the highlighted (causal) edges on top of the gray ones.
def draw_highlighted_edge(G, pos, color='red'):
    """Draw the causal edges of ``G`` plus a dashed arc linking the two hubs.

    An edge is treated as causal when its stored ``color`` attribute equals
    the module-level ``causal_color``. Each causal edge is drawn as a curved
    arc whose curvature depends on its key (``rad = 0.1 * (key + 1)``, the same
    rule ``draw_curved_edges`` uses), at a higher z-order than the gray arcs.
    Afterwards a dashed arc is drawn between the module-level hubs
    ``center_node1`` and ``center_node2``, regardless of whether ``G`` has an
    edge between them.

    Args:
        G: networkx MultiGraph whose edges carry a ``color`` attribute.
        pos: mapping from node to an (x, y) position; must contain both hubs.
        color: colour used for the highlighted arcs and for the dashed link.
    """
    ax = plt.gca()
    for u, v, key, edge_color in G.edges(keys=True, data='color'):
        # Select by stored colour rather than by (center, node, key == 1): the
        # causal edge's key depends on how many gray parallel edges precede it,
        # and an undirected edge may be reported as (node, center).
        if edge_color != causal_color:
            continue
        rad = 0.1 * (key + 1)
        ax.add_patch(patches.FancyArrowPatch(
            posA=pos[u],
            posB=pos[v],
            connectionstyle=f'arc3,rad={rad}',
            color=color,
            arrowstyle='->',
            mutation_scale=10,
            lw=3,
            zorder=2,  # higher z-order: drawn above the gray arcs
        ))

    # Dashed visual link between the two hub nodes.
    ax.add_patch(patches.FancyArrowPatch(
        posA=pos[center_node1],
        posB=pos[center_node2],
        connectionstyle='arc3,rad=0.5',
        color=color,
        arrowstyle='->',
        mutation_scale=10,
        lw=3,
        linestyle='--',
        zorder=2,
    ))
# Render the figure: nodes, labels and straight edges via nx.draw, then the
# curved gray arcs, then the highlighted arcs on top; finally save to disk.
pos = nx.spring_layout(G, seed=666)  # Fruchterman-Reingold layout, fixed seed
plt.figure(figsize=(12, 12))
draw_kwargs = {
    'with_labels': True,
    'node_size': 2500,
    'node_color': node_color,
    'style': edge_styles,
    'edge_color': edge_colors,
    'width': edge_widths,
    'font_size': 15,
    'font_weight': 'bold',
}
nx.draw(G, pos, **draw_kwargs)
# Gray arcs first, highlighted arcs second, so the highlights end up on top.
for painter, arc_color in (
    (draw_curved_edges, 'lightgray'),
    (draw_highlighted_edge, causal_color),
):
    painter(G, pos, color=arc_color)
plt.savefig('example.png', format='png', bbox_inches='tight')
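# networkx画图 (drawing graphs with networkx)
# 最新推荐文章于 2024-08-10 23:14:03 发布 (leftover web-page text: "latest recommended article published 2024-08-10 23:14:03")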