Pytorch Geometric 3 - 自定义数据集

原文地址:INTRODUCTION BY EXAMPLE

博客地址:陈小默的CSDN

PyTorch Geometric 为数据集提供了两个抽象类torch_geometric.data.Datasettorch_geometric.data.InMemoryDataset

其中InMemoryDataset继承自Dataset,如果要使用InMemoryDataset则需要使数据集大小适合存放在内存中。

首先需要一个保存有数据文件的文件夹root,该文件夹将被划分为两个文件夹,一个用于存储数据集的文件夹raw_dir和一个用来保存处理后数据集的文件夹processed_dir

除了root,类初始化的init函数还接收三个函数参数transform, pre_transformpre_filter,这些参数的默认值都是None。transform函数用于动态的转换数据对象。pre_transform函数在数据保存到硬盘之前进行一次转换。pre_filter用于过滤某些数据对象。

保存在内存中的数据集

为了创建InMemoryDataset,需要实现下面四个方法:

  • raw_file_names():该函数返回的文件名需要在raw_dir文件夹下找到才可以跳过下载过程。
  • processed_file_names():该函数放回的文件名需要在processed_dir中找到才可以跳过处理过程。
  • download():下载文件到 raw_dir
  • process():处理原始数据并保存在processed_dir

process()函数中,我们需要读入并创建一个Data对象列表之后将所有Data类型的对象保存在processed_dir文件夹中。由于无法将全部数据保存到内存中,需要在数据固化之前通过collate()函数保存Data对象的索引,此外,该函数还会返回一个slices字典用于从本地重建单个样例对象。于是在数据集对象new的时候,需要从本地读取self.dataself.slices对象。

import torch
from torch_geometric.data import InMemoryDataset


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        # Read data into huge `Data` list.
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

创建更大规模的数据集

有一些数据的规模太大,无法一次性加载到内存中,那么我们需要自己实现torch_geometric.data.Dataset,只需要额外实现两个方法:

  • len(): 返回数据集的长度
  • get():自定义加载Graph的方法
import os.path as osp

import torch
from torch_geometric.data import Dataset


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

常见问题

  1. 如何跳过 download()process() 过程?

    可以通过忽略(不重写)该函数的方法实现。

    class MyOwnDataset(Dataset):
        def __init__(self, transform=None, pre_transform=None):
            super(MyOwnDataset, self).__init__(None, transform, pre_transform)
    
  2. 以上这些函数是否都是必须要使用到的?

    对于动态创建的数据集而言,保存到内存中是非必要的,甚至在某些特定情况下,可以直接使用list充当数据集,如下:

    from torch_geometric.data import Data, DataLoader
    
    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
    
  • 7
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 15
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 15
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值