引言
在pyg的torch_geometric.datasets的包中,已经包含许多常见的数据集,但是针对的自己的需求去构建或者引用其他的一些数据集的时候,我们需要在pyg提供的函数的基础上进行数据的规范化。
在pyg中,可以构建两种类型的数据集,一种是In Memory Dataset,另一种是Larger Dataset。前者需要引入的包torch_geometric.data.InMemoryDataset
,适用于小数据集,直接全部加载至内存;后者需要引入torch_geometric.data.Dataset
,适用于分批大数据。需要注意的是前者是继承自后者的。
在创建数据集之前,有几点需要注意
- 每个数据集的根目录分成raw_dir和processed_dir,前者是下载的原始文件需要存储的地方;后者是处理后的数据存储的地方。
- 每个数据集可以经过
transform
,a pre_transform
和a pre_filter
函数,默认是None,这个在介绍之前例子的时候说过了,为了方便阅读,这里重述一遍。
创建"In Memory Datasets"
四个基本函数
- torch_geometric.data.InMemoryDataset.raw_file_names(): 返回一个文件列表,包含raw_dir中的文件目录。可以根据此列表来决定哪些需要下载或者已下载的直接跳过。
- torch_geometric.data.InMemoryDataset.processed_file_names():
返回一个处理后的文件列表,包含processed_dir中的文件目录。据此来决定需要跳过。也就说,在你处理完后,你再次运行该程序将不会二次处理。 - torch_geometric.data.InMemoryDataset.download():
将原始数据下载到 raw_dir 文件夹. 4. torch_geometric.data.InMemoryDataset.process():
处理原始数据将结果存放至 processed_dir 文件夹. 注意,这里需要将结果存储成Data格式。为解决python处理达标存储慢的的问题,通过torch_geometric.data.InMemoryDataset.collate()
将许多Data列表整理成一个很大的Data对象,并且返回一个slices索引字典,因此我们需要设置self.data
和self.slice
这两个属性。
简单数据集搭建
具体代码如下
import torch
from torch_geometric.data import InMemoryDataset
from torch_geometric.data import download_url
import os
from torch_geometric.io import read_planetoid_data
from torch_geometric.datasets import Planetoid
# data=Planetoid(name='Cora',root='data')
class SimpleExample(InMemoryDataset):
#这里参考InMemoryDataset类,这里transform和filter都没用到
def __init__(self,url= 'https://github.com/kimiyoung/planetoid/raw/master/data', dataname='cora',root='dataset', transform=None, pre_transform=None,pre_filter=None):
self.url=url
self.dataname=dataname
self.transform=transform
self.pre_filter=pre_filter
self.pre_transform=pre_transform
self.raw=os.path.join(root,dataname,'raw')
self.processed=os.path.join(root,dataname,'processed')
super(SimpleExample,self).__init__(root=root,transform=transform,pre_transform=pre_transform,pre_filter=pre_filter)
#其中processed_paths来自于Dataset类,返回数据
self.x, self.slices = torch.load(self.processed_paths[0])
#接下来写好四个函数,其中前两个是属性获取,所以这里采用property修饰器
#返回原始文件列表
@property
def raw_file_names(self):
names = ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']
return ['ind.{}.{}'.format(self.dataname.lower(), name) for name in names]
#返回需要跳过的文件列表
@property
def processed_file_names(self):
return ['data.pt']
#下载原生文件
def download(self):
for name in self.raw_file_names:
download_url('{}/{}'.format(self.url, name), self.raw)
def process(self):
data=read_planetoid_data(self.raw,self.dataname)
data_list = [data]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
#显示属性
def __repr__(self):
return '{}()'.format(self.dataname)
data=SimpleExample()
print(data[0])
print(data.processed_file_names)
具体的执行步骤如下图红框所示,在Dataset
类中可以找到
由此可见,程序默认先执行下载操作,然后再执行处理操作。
创建大数据集
三点注意
在创建大数据集的时候,需要注意的节点是 1. 继承自Dataset类 2. 自己实现len()方法,该方法返回你数据集的长度 3. 自己实现get()方法,实现载入单个图的逻辑。
简单数据集搭建
因为暂时用不到大的数据集,所以暂时先不处理,要是有兴趣,可以参考下面的代码
https://pytorch-geometric.readthedocs.io/en/latest/_modules/torch_geometric/datasets/yelp.html#Yelp
最后,有兴趣的朋友,可以加入群:777486287,方便大家交流探讨