【PyG入门学习】四:构建自己的数据集

1.简介

虽然Pytorch-Geometric提供了很多官方数据集,但是当需要构建自己的数据集的时候,就需要对如何使用dataset基类构造自己的数据集有所了解。库中提供了两个构建数据集的基类:torch_geometric.data.Datasettorch_geometric.data.InMemoryDataset,其中torch_geometric.data.InMemoryDataset继承了torch_geometric.data.Dataset,表示是否将整个数据集加载到内存中。
根据torchvision的习惯,每一个数据集都需要指定一个根目录,根目录下面需要分为两个文件夹,一个是raw_dir,这个表示下载的原始数据的存放位置,另一个是processed_dir,表示处理后的数据集存放位置。
另外,每一个数据集函数都可以传递函数transformpre_transformpre_filter,默认为Nonetransform函数用于数据对象被加载使用之前进行的动态转换(一般用于数据增强);pre_transform函数将数据对象保存到磁盘以前进行的转换,也就是得到processed_dir内数据文件之前对其调用(一般用于只需要计算一次的复杂预处理过程);pre_filter函数在数据进行保存之前进行过滤。

2.创建一次读入内存的数据

构建torch_geometric.data.InMemoryDataset,需要重写(区分重载和重写)四个函数:
(1)torch_geometric.data.InMemoryDataset.raw_file_names()
存放raw_dir目录下所有数据文件名的字符串列表,用于下载时的检查过程(正如之前的文章提到的,数据集下载的时候会检测是否已经存在,避免重复下载,也就是如何避免自动下载的httperror的解决方案)。
(2)torch_geometric.data.InMemoryDataset.processed_file_names()
和(1)类似,存放processed_dir目录下的文件名的列表,用于检测是否已经存在(不会二次处理)。
(3)torch_geometric.data.InMemoryDataset.download()
下载数据到raw_dir目录下。
(4)torch_geometric.data.InMemoryDataset.process()
raw_dir下的数据进行处理并存储到processed_dir目录下。
因此,可以发现关键在于第四个函数的实现,函数内首先需要读取原始数据并创建一个torch_geometric.data.Data对象的列表,并存储到processed_dir目录下面。直接存储和使用这个python-list时间代价很高,所以在存储之前调用torch_geometric.data.InMemoryDataset.collate()函数将列表转换为一个torch_geometric.data.Data对象。处理后的数据被整合到了一个数据对象中(作为返回值),同时返回一个slices字典来获取到这个数据对象中单个数据,所以总结下来process过程一共分四步:

  1. 加载数据创建列表
  2. 进行各种处理过程
  3. 调用collate()函数
  4. 存储本地

最后在数据类的构造函数中加载数据集并赋值给self.dataself.slices

import torch
from torch_geometric.data import InMemoryDataset

class MyDataset(InMemoryDataset):
    def __init__(self, root, transform=None, pre_transform=None):
        # 数据的下载和处理过程在父类中调用实现
        super(MyDataset, self).__init__(root, transform, pre_transform)
        # 加载数据
        self.data, self.slices = torch.load(self.processed_paths[0])

    # 将函数修饰为类属性
    @property
    def raw_file_names(self):
        return ['file_1', 'file_2']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        # download to self.raw_dir
        pass

    def process(self):
        data_list = [...]

        if self.pre_filter is not None:
            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_filter is not None:
            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)
        # 这里的save方式以及路径需要对应构造函数中的load操作
        torch.save((data, slices), self.processed_paths[0])

3.创建大规模数据

大数据集一般不会直接加载到内存中,这里构建数据集的时候需要继承父类torch_geometric.data.Dataset。在上面构建数据集时,重写了四个函数,此处还需要多实现两个函数:
(1)torch_geometric.data.Dataset.len()
返回数据集的文件个数。
(2)torch_geometric.data.Dataset.get()
实现对单个数据(图数据集的话一般是单个图)的加载逻辑。

import os.path as osp

import torch
# 这里就不能用InMemoryDataset了
from torch_geometric.data import Dataset

class MyDataset(Dataset):
    # 默认预处理函数的参数都是None
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['file_1', 'file_2']

    @property
    def processed_file_names(self):
        # 一次无法加载所有数据,所以对数据进行了分解
        return ['data1.pt', 'data2.pt', 'data3.pt']

    def download(self):
        # Download to raw_dir
        pass

    def process(self):
        i = 0
        # 遍历每一个文件路径
        for raw_path in self.raw_paths:
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1


    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data{}.pt',format(idx)))
        return data

4.相关细节

当我第一遍看完文档之后,心中还是存在很多疑惑的,第一,毕竟直接继承了一个父类,具体的流程是如何的,还不清楚,第二,没有亲自制作一个数据集,的确理解上存在模糊,下面对我个人的一些疑惑进行探索。

4.1具体的数据加载流程

这里的流程是指包括了个人定义的数据类内部的逻辑以及父类InMemoryDataset中的逻辑(先分析内存数据集):
1.对MyDataset实例化,此时调用类内构造函数__init__,先通过父类构造函数,再从本地加载数据,因此所有的关键操作都是在父类构造中发生的。
2.(在调用父类构造函数的时候,根据文档的官方例子我产生了两个疑惑,第一个是参数中没有传递pre_filter参数,但是后面为什么还要判断self.pre_filter,难道说默认的pre_filter不是None?而是父类中给了一个实现方式?第二个是参数中传递了transform,但是在重写的process函数并没有transform的过程,那么这个过程又是在哪里实现的呢?)在InMemoryDataset类中,构造函数为:

def __init__(self, root=None, transform=None, pre_transform=None,
             pre_filter=None):
    super(InMemoryDataset, self).__init__(root, transform, pre_transform,
                                          pre_filter)
    self.data, self.slices = None, None

其中transformpre_transformpre_filter都是函数句柄(callable),具体说明如下:
(1)transform接受参数类型为torch_geometric.data.Data并且返回一个转换后的版本(数据类型不变),在每一次数据加载到程序之前都会默认调用进行数据转换。
(2)pre_transform接收参数类型为torch_geometric.data.Data,返回转换后的版本,在数据被存储到硬盘之前进行转换(只发生一次)。
(3)pre_filter接受参数类型为torch_geometric.data.Data,返回布尔类型结果,相当于对原始数据的一个mask
可以看到InMemoryDataset中构造函数的参数,这三个函数参数都是None。这也就是解决了之前的第一个疑问,如果要用pre_filter,就必须传递该参数,否则为None
3.调用InMemoryDataset的父类Dataset的构造函数,其实此处就可以发现大部分的逻辑已经可以在Dataset类中看到了。先对之前的疑惑二进行解答何时调用transform,为什么在process中没有transform呢?

def __getitem__(self, idx):
    r"""Gets the data object at index :obj:`idx` and transforms it (in case
    a :obj:`self.transform` is given).
    In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a
    tuple, a  LongTensor or a BoolTensor, will return a subset of the
    dataset at the specified indices."""
    if isinstance(idx, int):
        data = self.get(self.indices()[idx])
        data = data if self.transform is None else self.transform(data)
        return data
    else:
        return self.index_select(idx)

这一段代码是源码Dataset类中的函数,可以看到这个函数是根据索引获取部分数据,idx为索引目标,可以是列表、元组、LongTensor或者BoolTensor。可以看到只有在访问数据元素时,才会调用transform函数。
4.在Dataset的构造函数中,有这么几行代码:

if 'download' in self.__class__.__dict__.keys():
    self._download()

if 'process' in self.__class__.__dict__.keys():
    self._process()

此处调用下载函数和处理函数,而self._download()会调用self.download()process同理。
5.将处理好的数据存储到本地,然后再加载到程序中。
以上就是详细的处理流程了,值得注意的是,如果需要下载数据,利用request相关技术,需要自己重写download()函数;如果要对数据进行预过滤、转换和预转换,需要定义外部函数作为参数传递给构造过程。

4.2 实例学习

看了上面的内容,可能还是不知道咋做,现在就通过官方数据集的源码进行一波分析。例子以Planetoid为例:

from torch_geometric.datasets import Planetoid

1.构造函数中transformpre_transform都设置了None,但是没有pre_filter参数,也就是说这里不允许传递pre_filter参数。

def __init__(self, root, name, transform=None, pre_transform=None):
    self.name = name
    super(Planetoid, self).__init__(root, transform, pre_transform)
    self.data, self.slices = torch.load(self.processed_paths[0])

该数据集只有一个数据文件,所以直接取索引0。
2. 下载函数如下:

 def download(self):
     for name in self.raw_file_names:
          download_url('{}/{}'.format(self.url, name), self.raw_dir)

遍历每一个文件名,然后调用download_url函数进行下载。

from torch_geometric.data import download_url

不过在download_urlDataset类中的_download函数中都进行防覆盖检测。
3.处理函数如下:

def process(self):
    data = read_planetoid_data(self.raw_dir, self.name)
    data = data if self.pre_transform is None else self.pre_transform(data)
    torch.save(self.collate([data]), self.processed_paths[0])

第一步读取数据,第二步转换,第三步存储,主要是第一步的操作,这里调用了一个函数read_planetoid_data,此函数读取本地文件后,进行了训练集、测试集、验证集的划分,并且构造了一个Data对象:

data = Data(x=x, edge_index=edge_index, y=y)
data.train_mask = train_mask
data.val_mask = val_mask
data.test_mask = test_mask

在存储之前调用了

self.collate([data])

该函数的具体内容在下一小节中讲解。

4.3 collate函数

collate函数在InMemoryDataset中实现,将一个python列表形式数据转换(每一个元素都是一个数据对象)为torch_geometric.data.InMemoryDataset内部存储数据的格式。这里每一个数据对象未必是Data类型(一般代表一个Graph),也可以是其他的,比如图片等。

data = data_list[0].__class__()

这一行代码可以对第一个元素的类名解析并重新构造一个同类型元素。

for item, key in product(data_list, keys):
	data[key].append(item[key])

利用笛卡尔积构造元组替代双层循环,并且将列表中所有数据元素的值存放到一个数据对象中。后面的代码进行了一些拼接过程,具体的见Github

  • 19
    点赞
  • 64
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值