可视化 DGL 图神经网络, 显示节点值和边的值

dgl的网络转为networkx后,边的标签无法显示,顶点权重也不能显示。下面的代码可以正常湿示顶点标签,权重和边的权重。

在学习图神经时,需要了解消息函数,聚合函数或更新函数的意义及查看结果,有一个可视化的图将是非常有意义的,下面的代码可以帮你实现dgl图的可视化。

import dgl
import torch as th
import networkx as nx
import matplotlib.pyplot as plt

u= th.tensor([0,1,2,3])
v=th.tensor([2,0,1,5])
g=dgl.graph((u,v))
print(g.edges())

g.ndata["desc"]=th.tensor([1,2,3,4,5,6])
g.edata["weights"]=th.tensor([4,5,6,7])

plt.figure(figsize=(8, 8))

#desc, weights
def ShowGraph(graph,nodeLabel,EdgeLabel):
    plt.figure(figsize=(8, 8))
    G=graph.to_networkx(node_attrs=nodeLabel.split(),edge_attrs=EdgeLabel.split())  #转换 dgl graph to networks
    pos=nx.spring_layout(G)
    nx.draw(G, pos,edge_color="grey", node_size=500,with_labels=True) # 画图,设置节点大小
    node_data = nx.get_node_attributes(G, nodeLabel)  # 获取节点的desc属性
    node_labels = { index:"N:"+ str(data)  for index,data in enumerate(node_data) }  #重新组合数据, 节点标签是dict, {nodeid:value,nodeid2,value2} 这样的形式
    pos_higher = {}
    
    for k, v in pos.items():  #调整下顶点属性显示的位置,不要跟顶点的序号重复了
        if(v[1]>0):
            pos_higher[k] = (v[0]-0.04, v[1]+0.04)
        else:
            pos_higher[k] = (v[0]-0.04, v[1]-0.04)
    nx.draw_networkx_labels(G, pos_higher, labels=node_labels,font_color="brown", font_size=12)  # 将desc属性,显示在节点上
    edge_labels = nx.get_edge_attributes(G, EdgeLabel) # 获取边的weights属性,
    

    edge_labels= {  (key[0],key[1]): "w:"+str(edge_labels[key].item())  for key in edge_labels } #重新组合数据, 边的标签是dict, {(nodeid1,nodeid2):value,...} 这样的形式
    nx.draw_networkx_edges(G,pos, alpha=0.5 )
    nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,font_size=12) # 将Weights属性,显示在边上

    print(G.edges.data())
    plt.show()



ShowGraph(g,"desc",'weights')

 

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值