CREATING YOUR OWN DATASETS

即使PyG已经包含了很多数据集,但是如果大家想使用自己的或非公开数据集,还是需要实现自己的 dataset。对于数据集的创建涉及到两个类, torch_geometric.data.Dataset 和 torch_geometric.data.InMemoryDataset

,其中第二个是第一个的子类,如果希望全部数据都在内存里则需要使用第二个类。每个数据集需要提供文件夹路径作为参数,其中一个 raw_dir存储数据集的源文件,而另一个参数 processed_dir存储处理过的文件。

每个数据集都会经过 transform,pre_transform,pre_filter三个函数,默认是 None。第一个函数在使用前动态的转化数据对象(所以最好用于数据增强);第二个函数是将数据集存储在磁盘前的转换函数(最好用于仅需做一次的大量预计算任务);最后一个函数在存储前过滤一些对象。


创建"InMemoryDataset"

为了创建这个数据集,需要实现下面四个基本方法:

raw_file_names:raw_dir文件列表,如果源文件在这里存在的话,就可以跳过下载。

processed_file_names:在 processed_dir里的文件列表,用于跳过处理。

download:将源文件下载到 raw_dir里面。

process:处理源数据并保存到 processed_dir里。在这里面,需要读取并创建一个 Data对象列表存储到上面的文件夹里,但是python存储是慢的,因此我们在存储前通过 collate将list合为一个大的 Data对象,然后从这个对象返回一个 slices字典用于重构单个样例。最后我们需要加载两个对象 self.data, self.slices。

对于其他更高级的方法参考torch_geometric.data

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
	# 初始化时将数据读入内存
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # 将源文件下载到`self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # 读数据到大的 `Data` 列表.
        data_list = [...]

        if self.pre_filter is not None: # 首先过滤
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None: 
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
	# 将处理后的数据存储
        torch.save((data, slices), self.processed_paths[0])

Creating “Larger” Datasets

要创建无法存储在内存的数据集,可以使用 torch_geometric.data.Dataset,它紧跟 torchvision 数据集的概念。 它期望另外实现以下方法:

torch_geometric.data.Dataset.len():返回数据集中示例的数量。

torch_geometric.data.Dataset.get():实现加载单个图形的逻辑。

在内部,torch_geometric.data.Dataset.__getitem__() 从 torch_geometric.data.Dataset.get() 获取数据对象,并可选择根据Transform对其进行变换。

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...
'''----------------------前面都是一样的---------------------'''
    def process(self): 
	# 这个函数是因为数据比较多,无法一次读入内存,所以以图为单位分开读取、处理、再存储
        idx = 0
        for raw_path in self.raw_paths:
            # 从 `raw_path`读取数据.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1


    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值