最近在学习图神经网络相关的知识,发现主要使用的库有两个:dgl和pyg,但是pyg的资料是英文的,而且网上资料也比较少,就翻译一下分享。
图数据处理
单张图被PyG表示为torch_geometric.data.Data
类型,有如下属性:
data.x
: 节点的特征矩阵,形状为 [num_nodes, num_node_features]
data.edge_index
: COO格式的图的边 shape [2, num_edges] and type torch.long
data.edge_attr
:边的特征矩阵 shape [num_edges, num_edge_features]
data.y
: 训练数据的标签,节点级的目标 shape [num_nodes, *] or 图级的目标 shape [1, *]
data.pos
: 节点的位置矩阵 shape [num_nodes, num_dimensions]
import torch
from torch_geometric.data import Data
# 一个无权无向图的例子
# 边的索引 COO格式 就是说第一行代表行索引 第二行代表列索引
# 节点0 1之间有一条边 12之间有一条边
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
# 节点0 1 2的特征分别是
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
# 不是COO格式的边 是使用边的两个节点的元组形式
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1],
[1, 0],
[1, 2],
[2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
对于data提供了如下方法进行访问:
# 可以使用类似字典的方式
print(data.keys)
>>> ['x', 'edge_index'] # 前面带箭头的表示输出
print(data['x'])
>>> tensor([[-1.0],
[0.0],
[1.0]])
for key, item in data:
print(f'{
key} found in data')
>>> x found in data
>>> edge_index found in data
'edge_attr' in data
>>> False