在上一篇的介绍中,主要讲了Pytorch-Geometric
的五个基础用例,但是其中存在一些问题还没有解决,下面开始一一解决,本文的重点是如何手动加载PyG的数据集。
1.关于创建Data实例时,维度异常的问题
问题描述:
在Data
创建过程中,edge_index
表示边的信息,x
为节点的特征向量,y
为目标值,如果y
的维度([num_nodes, *]
)和节点总数的维度是一样的,那就是node-level
;如果y
的维度是[1,*]
,那就是graph-level
。但是如果y
的维度不符合上述的两种情况,在创建过程会如何?
解决方案:
### Q1: X维度和Y的维度不统一
import torch
from torch_geometric.data import Data
# 构建边
edge_index = torch.tensor([
[3, 1, 1, 2],
[1, 3, 2, 1]], dtype=torch.long)
# 构建X
x = torch.tensor([[-1],
[0],
[1],[2]], dtype=torch.float)
y = torch.tensor([[1], [2], [3]], dtype=torch.float)
data = Data(x=x, y=y, edge_index=edge_index)
print(data)
如代码所示,其中节点共有3个,但是我创建了4个节点的特征向量和5个目标值,运行代码后没有出现错误,所以可以得知Data
实例化的过程中,是不会检查数据是否合理的,只是单纯的构建了一个复杂数据类型而已。
2.如何加载自己下载的数据集
问题描述:
在使用Dataset
进行数据集的创建时,经常会出现HttpError
这种样子的错误,所以手动下载数据集之后,再利用PyG
的函数进行构建,但是这个方式目前还没有找到官方的接口,所以要从源码的角度来处理。
解决方案:
这里用Cora数据集进行实验,在planetoid.py
文件中可以看到代码的下载地址为:
url = 'https://github.com/kimiyoung/planetoid/raw/master/data'
该文件内定义了一个类:
class Planetoid(InMemoryDataset):
里面有有一个download
函数用于下载数据集:
def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw_dir)
调用函数download_url
指定下载地址和下载后存放的目录,其中下载的列表为:
@property
def raw_file_names(self):
names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
return ['ind.{}.{}'.format(self.name.lower(), name) for name in names]
其中一共有八个文件。除了download
函数还有一个process
函数:
def process(self):
data = read_planetoid_data(self.raw_dir, self.name)
data = data if self.pre_transform is None else self.pre_transform(data)
torch.save(self.collate([data]), self.processed_paths[0])
第一行代码中的read_planetoid_data
函数是进行数据的加载并且切分出训练集、验证集、测试集,构造为Data实例:
data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask
第二行代码表示是否进行pre_transform
操作(也就是是否进行一次数据转换,一般3D点云数据比较常见);第三行利用torch.save
进行本地序列化。经过上面的一段分析,可以确定的是首先进行数据集下载,然后进行处理最后保存到本地一个新的序列化文件,所以只需要在下载过程跳过即可,但是考虑到这么一点,在之前学习的过程中,可以发现,当你第一次创建完数据集(下载到本地)之后,第二次时间比较短,所以一定存在防覆盖机制
来优化程序运行速度,于是在download_url
函数中找到了这一块代码:
if osp.exists(path): # pragma: no cover
if log:
print('Using exist file', filename)
return path
所以总结一下:
(1)根据URL下载自己的数据集;
(2)放到本地文件夹中,格式为:
其中Cora
文件夹是根目录,processed
是处理后torch.save
的,所以自己下载的数据放在raw文件夹中;
(3)调用Dataset
中的接口创建数据集即可。
from torch_geometric.datasets import Planetoid
dataset = Planetoid(root='data/', name='Cora')
print(dataset)
输出信息为:
Processing...
Done!
Cora()
之前GCNModel
的acc
为:
Accuracy: 0.8080