networkx 标签_在PyTorch框架下使用PyG和networkx对Graph进行可视化

09f313c69b7cfdf9c921eaedad6985c3.png

简介:本文介绍如何将Pytorch Geometric运行过程中得到的data或者tensor转换成networkx可以处理的格式,进行可视化。

其中Pytorch geometric的地址为

PyTorch Geometric Documentation​pytorch-geometric.readthedocs.io

方法一

根据networkx的文档

draw_networkx - NetworkX 1.10 documentation​networkx.github.io

我们可以写出来一个非常简单的例子,如下:

import 

运行程序之后,可以得到下面的图,(偷了一个懒,没有加label之类的信息)

f8678acd15476da6e5106b34ab3a4b67.png

这个例子给我们的启发就是,我们可以将PyG得到的edge_index转成numpy的格式,然后传给nx,下面是根据这个写的一个函数:

在PyG中,边的表示放在了edge_index中,由一个二维的矩阵构成,edge_index[0]表示节点edge_index[1]表示另一个节点。
def draw(edge_index, name=None):
    G = nx.Graph(node_size=15, font_size=8)
    src = edge_index[0].cpu().numpy()
    dst = edge_index[1].cpu().numpy()
    edgelist = zip(src, dst)
    for i, j in edgelist:
        G.add_edge(i, j)
    plt.figure(figsize=(20, 14)) # 设置画布的大小
    nx.draw_networkx(G)
    plt.savefig('{}.png'.format(name if name else 'path'))

注:该方法可以用于模型中的forward函数,用于分析cov,pool等操作

下面是与上面思想一致可以直接运行的一个例子

from torch_geometric.datasets import KarateClub
import networkx as nx
import matplotlib.pyplot as plt
dataset = KarateClub()
edge, x, y = dataset[0]
# edge, x, y 每个维度都为2,其中第一维度是name,第二个维度是data
# x表示的是结点,y表示的标签,edge表示的连边, 由两个维度的tensor构成
x_np = x[1].numpy()
y_np = y[1].numpy()
g = nx.Graph()
name, edgeinfo = edge
src = edgeinfo[0].numpy()
dst = edgeinfo[1].numpy()
edgelist = zip(src, dst)
for i, j in edgelist:
    g.add_edge(i, j)
nx.draw(g)
plt.savefig('test.png')
plt.show()

方法二

其实,torch_geometric.utils中已经带有to_networkx的函数可以直接将格式为torch_geometric.data.Data 的数据转换为networkx.DiGraph的格式,该格式可以直接networkx处理,但是我们提前要得到torch_geometric.data.Data的数据格式

import networkx as nx
from torch_geometric.utils.convert import to_networkx
def draw(Data):
    G = to_networkx(Data)
    nx.draw(G)
    plt.savefig("path.png")
    plt.show()

这个一般可以用于在model加载数据之前数据的分析,比如下面的例子

 for i, data in enumerate(train_loader):
        draw(data)
        data = data.to(args.device)
        out = model(data)
        loss = F.nll_loss(out, data.y)
        print("Training loss:{}".format(loss.item()))
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

上面的函数是在graph classification进行分析的一段代码,可以把batch size的设置为1,那么for循环中得到就是一个graph的数据,在把数据feed给模型之前,我们可以通过该方法分析一下原始的数据是什么样子的。

如果有什么问题,或者获取完整代码(其实上面已经很完整了),请联系

5b58ba737798ca6f6e1264df158d79cf.png
  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值