博客地址:陈小默的CSDN
PyTorch Geometric 为数据集提供了两个抽象类torch_geometric.data.Dataset
和torch_geometric.data.InMemoryDataset
。
其中InMemoryDataset
继承自Dataset
,如果要使用InMemoryDataset
则需要使数据集大小适合存放在内存中。
首先需要一个保存有数据文件的文件夹root
,该文件夹将被划分为两个文件夹,一个用于存储数据集的文件夹raw_dir
和一个用来保存处理后数据集的文件夹processed_dir
。
除了root
,类初始化的init函数还接收三个函数参数transform
, pre_transform
和pre_filter
,这些参数的默认值都是None。transform
函数用于动态的转换数据对象。pre_transform
函数在数据保存到硬盘之前进行一次转换。pre_filter
用于过滤某些数据对象。
保存在内存中的数据集
为了创建InMemoryDataset
,需要实现下面四个方法:
raw_file_names()
:该函数返回的文件名需要在raw_dir
文件夹下找到才可以跳过下载过程。processed_file_names()
:该函数放回的文件名需要在processed_dir
中找到才可以跳过处理过程。download()
:下载文件到raw_dir
process()
:处理原始数据并保存在processed_dir
在process()
函数中,我们需要读入并创建一个Data对象列表之后将所有Data类型的对象保存在processed_dir
文件夹中。由于无法将全部数据保存到内存中,需要在数据固化之前通过collate()
函数保存Data对象的索引,此外,该函数还会返回一个slices
字典用于从本地重建单个样例对象。于是在数据集对象new的时候,需要从本地读取self.data
和 self.slices
对象。
import torch
from torch_geometric.data import InMemoryDataset
class MyOwnDataset(InMemoryDataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
self.data, self.slices = torch.load(self.processed_paths[0])
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data.pt']
def download(self):
# Download to `self.raw_dir`.
def process(self):
# Read data into huge `Data` list.
data_list = [...]
if self.pre_filter is not None:
data_list = [data for data in data_list if self.pre_filter(data)]
if self.pre_transform is not None:
data_list = [self.pre_transform(data) for data in data_list]
data, slices = self.collate(data_list)
torch.save((data, slices), self.processed_paths[0])
创建更大规模的数据集
有一些数据的规模太大,无法一次性加载到内存中,那么我们需要自己实现torch_geometric.data.Dataset
,只需要额外实现两个方法:
len()
: 返回数据集的长度get()
:自定义加载Graph的方法
import os.path as osp
import torch
from torch_geometric.data import Dataset
class MyOwnDataset(Dataset):
def __init__(self, root, transform=None, pre_transform=None):
super(MyOwnDataset, self).__init__(root, transform, pre_transform)
@property
def raw_file_names(self):
return ['some_file_1', 'some_file_2', ...]
@property
def processed_file_names(self):
return ['data_1.pt', 'data_2.pt', ...]
def download(self):
# Download to `self.raw_dir`.
def process(self):
i = 0
for raw_path in self.raw_paths:
# Read data from `raw_path`.
data = Data(...)
if self.pre_filter is not None and not self.pre_filter(data):
continue
if self.pre_transform is not None:
data = self.pre_transform(data)
torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
i += 1
def len(self):
return len(self.processed_file_names)
def get(self, idx):
data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
return data
常见问题
-
如何跳过
download()
和process()
过程?可以通过忽略(不重写)该函数的方法实现。
class MyOwnDataset(Dataset): def __init__(self, transform=None, pre_transform=None): super(MyOwnDataset, self).__init__(None, transform, pre_transform)
-
以上这些函数是否都是必须要使用到的?
对于动态创建的数据集而言,保存到内存中是非必要的,甚至在某些特定情况下,可以直接使用list充当数据集,如下:
from torch_geometric.data import Data, DataLoader data_list = [Data(...), ..., Data(...)] loader = DataLoader(data_list, batch_size=32)