基本图变量
PyG中的基本图变量为torch_geometric.data.Data,具有以下参数:
data.x: Node feature matrix with shape [num_nodes, num_node_features]
data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
data.pos: Node position matrix with shape [num_nodes, num_dimensions]
建立一个图,如下:
import torch
from torch_geometric.data import Data
#edge_index定义了所有图中边的起点终点,在无向图中,0-1和1-0是一条边,但是在PyG中认为是两条不同方向边,这里需要都定义出来,代表0-1这条边是双向的,1-2同理。
#简单理解就是PyG中默认的图的边是单向的,所以如果要构建双向/无向边就需要起点终点定义两次
#type选择long是PyG中边的默认格式
edge_index = torch.tensor([[0, 1], [1, 0], [1, 2], [2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index.t().contiguous())
#data的参数如下:edge_index代表边参数[2,4]中2是包中预设值,4为图中PyG定义下的边数量。
#x代表图顶点参数,[3,1]中3代表有3个顶点,1代表顶点的特征数为1
print(data)
>>> Data(edge_index=[2, 4], x=[3, 1])
print(data.keys)
>>>