【PyG】创建自己的数据集-图神经网络

即使PyG已经包含了很多数据集,但是如果大家想使用自己的或非公开数据集,还是需要实现自己的 dataset。对于数据集的创建涉及到两个类, torch_geometric.data.Datasettorch_geometric.data.InMemoryDataset

,其中第二个是第一个的子类,如果希望全部数据都在内存里则需要使用第二个类。每个数据集需要提供文件夹路径作为参数,其中一个 raw_dir存储数据集的源文件,而另一个参数 processed_dir存储处理过的文件。

每个数据集都会经过 transform,pre_transform,pre_filter三个函数,默认是 None。第一个函数在使用前动态的转化数据对象(所以最好用于数据增强);第二个函数是将数据集存储在磁盘前的转换函数(最好用于仅需做一次的大量预计算任务);最后一个函数在存储前过滤一些对象。

创建"InMemoryDataset"

为了创建这个数据集,需要实现下面四个基本方法:

raw_file_namesraw_dir文件列表,如果源文件在这里存在的话,就可以跳过下载。

processed_file_names:在 processed_dir里的文件列表,用于跳过处理。

download:将源文件下载到 raw_dir里面。

process:处理源数据并保存到 processed_dir里。在这里面,需要读取并创建一个 Data对象列表存储到上面的文件夹里,但是python存储是慢的,因此我们在存储前通过 collate将list合为一个大的 Data对象,然后从这个对象返回一个 slices字典用于重构单个样例。最后我们需要加载两个对象 self.data, self.slices

对于其他更高级的方法参考torch_geometric.data

下面是一个简单的例子。

import torch
from torch_geometric.data import InMemoryDataset, download_url


class MyOwnDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)
	# 初始化时将数据读入内存
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # 将源文件下载到`self.raw_dir`.
        download_url(url, self.raw_dir)
        ...

    def process(self):
        # 读数据到大的 `Data` 列表.
        data_list = [...]

        if self.pre_filter is not None: # 首先过滤
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None: 
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
	# 将处理后的数据存储
        torch.save((data, slices), self.processed_paths[0])

创建一个更大的Datasets

对于更大的不能放在内存中的数据就需要用到 torch_geometric.data.Dataset,他还需要实现以下方法:

len:返回数据集中的样本数。

get:实现读取一个图的逻辑。

还有 __getitem__()方法从 get()中获取一个数据对象,并根据 transform选择性地转化他们。

下面是一个简单的例子:

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url


class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super().__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...
'''----------------------前面都是一样的---------------------'''
    def process(self): 
	# 这个函数是因为数据比较多,无法一次读入内存,所以以图为单位分开读取、处理、再存储
        idx = 0
        for raw_path in self.raw_paths:
            # 从 `raw_path`读取数据.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, f'data_{idx}.pt'))
            idx += 1


    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, f'data_{idx}.pt'))
        return data

一些高频问题

  1. 如何跳过 downloadprocess的执行?
    不对这两个方法进行重载就好了~

    class MyOwnDataset(Dataset):
        def __init__(self, transform=None, pre_transform=None):
            super().__init__(None, transform, pre_transform)
    
  2. 我真的需要使用这些数据集的接口吗?

    不需要!仅仅是将Data合并为一个list,将他们传进 DataLoader 即可。

    from torch_geometric.data import Data
    from torch_geometric.loader import DataLoader
    
    data_list = [Data(...), ..., Data(...)]
    loader = DataLoader(data_list, batch_size=32)
    

    参考链接

    CREATING YOUR OWN DATASETS

  • 5
    点赞
  • 38
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值