PyG学习笔记(1)

本文详细介绍了如何在PyTorchGeometric中使用`Data`对象构建无向图,处理不连续张量的连续性,以及`validate()`方法的作用。同时涵盖了`Data`对象的使用、图的属性检查和常见的基准数据集加载方法。
摘要由CSDN通过智能技术生成
import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
print(data)

1、简单的无向图构建代码解释:

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)

edge_index表示了一个无向图中的 4 条边,这行代码中的矩阵含义得按列来看:

第一列[0,1]代表节点0和节点1之间有一条边

第二列[1,0]代表节点1和节点0之间有一条边  (无向图)

第三列[1,2]代表节点1和节点2之间有一条边

第四列[2,1]代表节点2和节点1之间有一条边

x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

x代表一个3x1的特征张量,这个矩阵得按行来看,每一行代表一个节点的特征:

第一个节点的特征值为-1

第二个节点的特征值为0

第三个节点的特征值为1

data = Data(x=x, edge_index=edge_index)

这行创建了一个 Data 对象,传入了节点特征张量 x 和边的索引 edge_index。这样就构建了一个包含节点特征和边信息的图数据对象。

可以得到输出结果,可以看出data包含了关于图结构的基本信息,以及每个节点的特征维度:

 Data(edge_index=[2, 4], x=[3, 1])

edge_index=[2, 4]:该图有两条边,每条边由两个节点组成,这个信息与我们之前定义的边索引形状相符,即 2x4 的张量,其中第一行表示起始节点,第二行表示结束节点。

x=[3, 1]:表示该图有三个节点,并且每个节点的特征维度为1

该图表示以上图的结构:

2、如果要确保传入Data中的张量的储存是连续的,那需要用到contiguous() 方法:

假设有一个不连续的张量,我们想要将其传递给一个需要连续张量作为输入的函数。

import torch

# 创建一个不连续的张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
print("原始张量:")
print(x)
print("是否连续:", x.is_contiguous())

# 选择子张量,这会导致不连续性
sub_x = x[:, 1:]
print("选择子张量后的张量:")
print(sub_x)
print("是否连续:", sub_x.is_contiguous())

# 尝试传递不连续的张量给函数,会导致错误
try:
    torch.sum(sub_x, dim=1)
except RuntimeError as e:
    print("错误信息:", e)

# 使用contiguous()方法使其连续
sub_x_contiguous = sub_x.contiguous()
print("使用contiguous()方法后的张量:")
print(sub_x_contiguous)
print("是否连续:", sub_x_contiguous.is_contiguous())

# 现在可以成功地传递连续的张量给函数
sum_result = torch.sum(sub_x_contiguous, dim=1) #dim=0代表一维按列相加求和 ,dim=1代表沿着张量的第二个维度进行求和,也就是按行相加求和
print("计算结果:", sum_result)

可以得到输出如下所示:

可以看出,原本的原始张量的值是从1-6连续的,之后取得子张量后矩阵的值变为了2、3、5、6,不是连续的了,所以需要使用contiguous()方法使张量连续后才能将正确的将张量传递给函数进行sum计算。

原始张量:
tensor([[1, 2, 3],
        [4, 5, 6]])
是否连续: True
选择子张量后的张量:
tensor([[2, 3],
        [5, 6]])
是否连续: False
使用contiguous()方法后的张量:
tensor([[2, 3],
        [5, 6]])
是否连续: True
计算结果: tensor([7, 9])

3、validate()检查数据对象(Data中的属性)

data.validate(raise_on_error=True)

        当调用 validate() 方法时,它会检查 Data 对象中的各种属性,如节点特征、边索引等,并根据定义的规则来验证这些属性是否符合要求。如果发现数据不符合要求,它可能会引发异常或者返回 False。如果数据符合要求,则返回 True

使用 validate() 方法可以确保在使用图数据进行训练或其他操作时,数据的一致性和正确性,从而避免潜在的错误。

4、Data 对象在PyTorch Geometric 中的作用

        它不仅用于存储图数据中的节点级、边级或图级属性,还提供了一些有用的工具函数,用于对图数据进行处理、操作和分析。

import torch
from torch_geometric.data import Data

edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)
data = Data(x=x, edge_index=edge_index)
print(data.keys())
# ['edge_index', 'x']
print(data['x'])
# tensor([[-1.],[ 0.],[ 1.]])
for key, item in data:
    print(f'{key} found in data')
# edge_index found in data
# x found in data
print('edge_attr' in data)
# False
print(data.num_nodes)  # 节点数
# 3
print(data.num_edges)  # 边数
# 4
print(data.num_node_features)  # 节点特征数
# 1
print(data.has_isolated_nodes())  # 是否有间断点
# False
print(data.is_directed())  # 检查图是否是有向图

5、常见的基准数据集

        PyG 包含许多常见的基准数据集,例如所有 Planetoid 数据集(Cora、Citeseer、Pubmed)、来自 TUDatasets 的所有图分类数据集及其清理版本、QM7 和 QM9 数据集,以及一些 3D 网格/点云数据集,如 FAUST、ModelNet10/40 和 ShapeNet。

初始化数据集非常简单。对数据集的初始化将自动下载其原始文件并将其处理成前面描述的 Data 格式。例如,要加载 ENZYMES 数据集(包含 6 个类别的 600 个图),请键入:

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值