import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2],
[1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
print(data)
1、简单的无向图构建代码解释:
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
edge_index表示了一个无向图中的 4 条边,这行代码中的矩阵含义得按列来看:
第一列[0,1]代表节点0和节点1之间有一条边
第二列[1,0]代表节点1和节点0之间有一条边 (无向图)
第三列[1,2]代表节点1和节点2之间有一条边
第四列[2,1]代表节点2和节点1之间有一条边
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
x代表一个3x1的特征张量,这个矩阵得按行来看,每一行代表一个节点的特征:
第一个节点的特征值为-1
第二个节点的特征值为0
第三个节点的特征值为1
data = Data(x=x, edge_index=edge_index)
这行创建了一个 Data
对象,传入了节点特征张量 x
和边的索引 edge_index
。这样就构建了一个包含节点特征和边信息的图数据对象。
可以得到输出结果,可以看出data包含了关于图结构的基本信息,以及每个节点的特征维度:
Data(edge_index=[2, 4], x=[3, 1])
edge_index=[2, 4]:该图有两条边,每条边由两个节点组成,这个信息与我们之前定义的边索引形状相符,即 2x4 的张量,其中第一行表示起始节点,第二行表示结束节点。
x=[3, 1]:表示该图有三个节点,并且每个节点的特征维度为1
该图表示以上图的结构:
2、如果要确保传入Data中的张量的储存是连续的,那需要用到contiguous()
方法:
假设有一个不连续的张量,我们想要将其传递给一个需要连续张量作为输入的函数。
import torch
# 创建一个不连续的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始张量:")
print(x)
print("是否连续:", x.is_contiguous())
# 选择子张量,这会导致不连续性
sub_x = x[:, 1:]
print("选择子张量后的张量:")
print(sub_x)
print("是否连续:", sub_x.is_contiguous())
# 尝试传递不连续的张量给函数,会导致错误
try:
torch.sum(sub_x, dim=1)
except RuntimeError as e:
print("错误信息:", e)
# 使用contiguous()方法使其连续
sub_x_contiguous = sub_x.contiguous()
print("使用contiguous()方法后的张量:")
print(sub_x_contiguous)
print("是否连续:", sub_x_contiguous.is_contiguous())
# 现在可以成功地传递连续的张量给函数
sum_result = torch.sum(sub_x_contiguous, dim=1) #dim=0代表一维按列相加求和 ,dim=1代表沿着张量的第二个维度进行求和,也就是按行相加求和
print("计算结果:", sum_result)
可以得到输出如下所示:
可以看出,原本的原始张量的值是从1-6连续的,之后取得子张量后矩阵的值变为了2、3、5、6,不是连续的了,所以需要使用contiguous()方法使张量连续后才能将正确的将张量传递给函数进行sum计算。
原始张量:
tensor([[1, 2, 3],
[4, 5, 6]])
是否连续: True
选择子张量后的张量:
tensor([[2, 3],
[5, 6]])
是否连续: False
使用contiguous()方法后的张量:
tensor([[2, 3],
[5, 6]])
是否连续: True
计算结果: tensor([7, 9])
3、validate()检查数据对象(Data中的属性)
data.validate(raise_on_error=True)
当调用 validate()
方法时,它会检查 Data
对象中的各种属性,如节点特征、边索引等,并根据定义的规则来验证这些属性是否符合要求。如果发现数据不符合要求,它可能会引发异常或者返回 False
。如果数据符合要求,则返回 True
。
使用 validate()
方法可以确保在使用图数据进行训练或其他操作时,数据的一致性和正确性,从而避免潜在的错误。
4、Data 对象在PyTorch Geometric 中的作用
它不仅用于存储图数据中的节点级、边级或图级属性,还提供了一些有用的工具函数,用于对图数据进行处理、操作和分析。
import torch
from torch_geometric.data import Data
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
print(data.keys())
# ['edge_index', 'x']
print(data['x'])
# tensor([[-1.],[ 0.],[ 1.]])
for key, item in data:
print(f'{key} found in data')
# edge_index found in data
# x found in data
print('edge_attr' in data)
# False
print(data.num_nodes) # 节点数
# 3
print(data.num_edges) # 边数
# 4
print(data.num_node_features) # 节点特征数
# 1
print(data.has_isolated_nodes()) # 是否有间断点
# False
print(data.is_directed()) # 检查图是否是有向图
5、常见的基准数据集
PyG 包含许多常见的基准数据集,例如所有 Planetoid 数据集(Cora、Citeseer、Pubmed)、来自 TUDatasets 的所有图分类数据集及其清理版本、QM7 和 QM9 数据集,以及一些 3D 网格/点云数据集,如 FAUST、ModelNet10/40 和 ShapeNet。
初始化数据集非常简单。对数据集的初始化将自动下载其原始文件并将其处理成前面描述的 Data 格式。例如,要加载 ENZYMES 数据集(包含 6 个类别的 600 个图),请键入: