dgl的网络转为networkx后,边的标签无法显示,顶点权重也不能显示。下面的代码可以正常湿示顶点标签,权重和边的权重。
在学习图神经时,需要了解消息函数,聚合函数或更新函数的意义及查看结果,有一个可视化的图将是非常有意义的,下面的代码可以帮你实现dgl图的可视化。
import dgl
import torch as th
import networkx as nx
import matplotlib.pyplot as plt
u= th.tensor([0,1,2,3])
v=th.tensor([2,0,1,5])
g=dgl.graph((u,v))
print(g.edges())
g.ndata["desc"]=th.tensor([1,2,3,4,5,6])
g.edata["weights"]=th.tensor([4,5,6,7])
plt.figure(figsize=(8, 8))
#desc, weights
def ShowGraph(graph,nodeLabel,EdgeLabel):
plt.figure(figsize=(8, 8))
G=graph.to_networkx(node_attrs=nodeLabel.split(),edge_attrs=EdgeLabel.split()) #转换 dgl graph to networks
pos=nx.spring_layout(G)
nx.draw(G, pos,edge_color="grey", node_size=500,with_labels=True) # 画图,设置节点大小
node_data = nx.get_node_attributes(G, nodeLabel) # 获取节点的desc属性
node_labels = { index:"N:"+ str(data) for index,data in enumerate(node_data) } #重新组合数据, 节点标签是dict, {nodeid:value,nodeid2,value2} 这样的形式
pos_higher = {}
for k, v in pos.items(): #调整下顶点属性显示的位置,不要跟顶点的序号重复了
if(v[1]>0):
pos_higher[k] = (v[0]-0.04, v[1]+0.04)
else:
pos_higher[k] = (v[0]-0.04, v[1]-0.04)
nx.draw_networkx_labels(G, pos_higher, labels=node_labels,font_color="brown", font_size=12) # 将desc属性,显示在节点上
edge_labels = nx.get_edge_attributes(G, EdgeLabel) # 获取边的weights属性,
edge_labels= { (key[0],key[1]): "w:"+str(edge_labels[key].item()) for key in edge_labels } #重新组合数据, 边的标签是dict, {(nodeid1,nodeid2):value,...} 这样的形式
nx.draw_networkx_edges(G,pos, alpha=0.5 )
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels,font_size=12) # 将Weights属性,显示在边上
print(G.edges.data())
plt.show()
ShowGraph(g,"desc",'weights')