代码
import torch
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
def graph_showing(data):
    """Draw a single graph with networkx and display it via matplotlib.

    Args:
        data: torch_geometric.data.Data, or any object supporting
            ``data['edge_index']``. ``edge_index`` is assumed to follow the
            PyG convention of a (2, num_edges) integer tensor.
    """
    G = nx.Graph()
    # Register every node up front. Building the graph from edges alone
    # would drop nodes that have no incident edge, so they would silently
    # disappear from the drawing.
    num_nodes = getattr(data, 'num_nodes', None)
    if num_nodes is not None:
        G.add_nodes_from(range(num_nodes))
    # Transpose (2, num_edges) -> (num_edges, 2) so each row is one (u, v)
    # pair. Tensor.tolist() yields plain Python ints on any device, so no
    # explicit .cpu() / numpy conversion is needed.
    edge_list = data['edge_index'].t().tolist()
    # nx.Graph is undirected. If edge_index stores both directions of an
    # edge, the two pairs collapse into a single edge here.
    G.add_edges_from(edge_list)
    nx.draw(G)
    plt.show()
展示
下面的示例使用 TUDataset 中的 ENZYMES 数据集：
from torch_geometric.datasets import TUDataset

# Load the ENZYMES dataset from the TUDataset collection, stored under
# ./data/ENZYMES.
dataset = TUDataset(root='./data/ENZYMES', name='ENZYMES')
# For every graph in the dataset: print its summary, then draw it.
# graph_showing calls plt.show() once per graph, so one figure is shown
# per iteration.
for data in dataset:
    print(data)
    graph_showing(data)
输出：