networkx画图

import matplotlib.pyplot as plt
import networkx as nx
import matplotlib.patches as patches
import torch
causal_color = '#F08080'

# 创建两个图的节点数据




ori_nodes_1 = torch.tensor([ 996,   68, 1305, 1305, 1305, 1305, 1305, 8046, 2654, 8020, 2138, 8013,
        8011,  281, 1305, 1305, 1552, 6368, 6368, 6368, 6368, 6368,  935,  935,
        5209], device='cuda:0', dtype=torch.int32)
causal_nodes_1 = torch.tensor([1305, 5209,  935,   68], device='cuda:0', dtype=torch.int32)

ori_nodes_2 = torch.tensor([41, 1296, 1296, 1296, 1296, 1296, 7988, 1305, 1296, 1296, 1296, 1296,
        1296, 1305, 2353,   68, 1296, 1296, 1296, 1305, 2353, 1305, 1305,   79,
          79], device='cuda:0', dtype=torch.int32)
causal_nodes_2 = torch.tensor([68, 1305, 1305, 1296], device='cuda:0', dtype=torch.int32)


causal_color = '#F08080'
# 转换为列表
ori_nodes_1 = ori_nodes_1.tolist()
causal_nodes_1 = causal_nodes_1.tolist()
ori_nodes_2 = ori_nodes_2.tolist()
causal_nodes_2 = causal_nodes_2.tolist()


# 定义中心节点
center_node1 = ori_nodes_1[0]
other_nodes_1 = ori_nodes_1[1:]
center_node2 = ori_nodes_2[0]
other_nodes_2 = ori_nodes_2[1:]

# 创建多重图
G = nx.MultiGraph()

# 添加节点和边
G.add_node(center_node1)
G.add_nodes_from(other_nodes_1)

G.add_node(center_node2)
G.add_nodes_from(other_nodes_2)

G.add_edges_from([(center_node1, other_node) for other_node in other_nodes_1],color='lightgray',linestyle ='-',width=3) 
G.add_edges_from([(center_node2, other_node) for other_node in other_nodes_2],color='lightgray',linestyle ='-',width=3)

# 将因果边设为红色
G.add_edges_from([(center_node1, other_node) for other_node in causal_nodes_1],color=causal_color,linestyle ='-',width=3) 
G.add_edges_from([(center_node2, other_node) for other_node in causal_nodes_2],color=causal_color,linestyle ='-',width=3)


# 设置图的属性
# 为每个节点指定颜色
node_color = []
for node in G.nodes():

        if node == center_node1 or node == center_node2:
                node_color.append(causal_color)  # 根节点的颜色用绿色
        elif node in causal_nodes_1: # 图一的因果节点用红色
                node_color.append(causal_color)
        elif node in causal_nodes_2: # 图二的因果节点用红色
                node_color.append(causal_color)
        else:
                node_color.append('lightgray') # 两者的非因果节点用灰色

# 获取边颜色
edge_colors = [G[u][v][key]['color'] for u, v, key in G.edges(keys=True)]
# # 获取边类型
edge_styles = [G[u][v][key]['linestyle'] for u, v, key in G.edges(keys=True)]
# 设置边的宽度
edge_widths = [G[u][v][key]['width'] for u, v, key in G.edges(keys=True)]


# 去掉自环
self_loops = list(nx.selfloop_edges(G))
G.remove_edges_from(self_loops)



# 绘制所有边(不包括置顶边)
def draw_curved_edges(G, pos, color='lightgray'):
    for u, v, key in G.edges(keys=True) :
        if (u, v,key) != (center_node1,center_node2,1) and \
        (u,v,key) not in [(center_node1, other_node,1) for other_node in causal_nodes_1] and\
        (u,v,key) not in [(center_node2, other_node,1) for other_node in causal_nodes_2]:  # 不绘制特定的边
                rad = 0.1 * (key + 1)
                edge = patches.FancyArrowPatch(
                        posA=pos[u],
                        posB=pos[v],
                        connectionstyle=f"arc3,rad={rad}",
                        color=color,
                        arrowstyle='->',
                        mutation_scale=10,
                        lw=3,
                        zorder=0  # 设置较低的z-order
                )
                plt.gca().add_patch(edge)

# 绘制置顶边
def draw_highlighted_edge(G, pos, color='red'):
    for u, v, key in G.edges(keys=True) :
        if (u, v,key) != (center_node1,center_node2,1) and ((u,v,key) in [(center_node1, other_node,1) for other_node in causal_nodes_1] or\
        (u,v,key) in [(center_node2, other_node,1) for other_node in causal_nodes_2]):  # 不绘制特定的边
                rad = 0.1 * (key + 1)
                edge = patches.FancyArrowPatch(
                posA=pos[u],
                posB=pos[v],
                connectionstyle=f'arc3,rad={rad}',
                color=color,
                arrowstyle='->',
                mutation_scale=10,
                lw=3,  # 更粗的边线宽度
                zorder=2  # 设置较高的z-order

                )
                plt.gca().add_patch(edge)
        
        
        u,v = center_node1,center_node2
        edge = patches.FancyArrowPatch(
        posA=pos[u],
        posB=pos[v],
        connectionstyle=f'arc3,rad=0.5',
        color=color,
        arrowstyle='->',
        mutation_scale=10,
        lw=3,  # 更粗的边线宽度
        linestyle='--',
        zorder=2  # 设置较高的z-order
        
        )
        plt.gca().add_patch(edge)

# 绘制图
pos = nx.spring_layout(G,seed=666)  # 使用 Fruchterman-Reingold 布局算法

plt.figure(figsize=(12, 12))

nx.draw(G, pos, with_labels=True, node_size=2500, node_color=node_color,  style=edge_styles,\
        edge_color = edge_colors, width=edge_widths, font_size=15, font_weight='bold')

# 绘制弯曲的边
draw_curved_edges(G, pos, color='lightgray')

# 绘制置顶边
draw_highlighted_edge(G, pos, color=causal_color)


# 保存图像
plt.savefig('example.png', format='png', bbox_inches='tight')

  • 3
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值