PyG教程(4):自定义数据集

一.前言

在PyG中,除了直接使用它自带的benchmark数据集外,用户还可以自定义数据集,其方式与Pytorch类似,需要继承数据集类。PyG中提供了两个数据集抽象类:

  • torch_geometric.data.Dataset:用于构建大型数据集(非内存数据集);
  • torch_geometric.data.InMemoryDataset:用于构建内存数据集(小数据集),继承自Dataset

下面是对其的详细介绍。

二.内存数据集

2.1 创建说明

在PyG中要构建自己的内存数据集需要先继承InMemoryDataset类,并实现如下方法:

  • raw_file_names():返回原始数据集的文件名列表,若self.raw_dir中没有该列表中的文件,则会通过download()进行下载;
  • processed_file_names():返回process()方法处理后的文件名列表,若self.processed_dir中没有确实该列表中的文件,则需要通过process()方法进行处理;
  • download():下载原始数据集到self.raw_dir中;
  • process():处理原始数据集,并保存到processed_dir中。

在前两个方法中,若只有单个文件,则直接返回文件字符串即可,不一定要返回list对象。

另外,上面的self.raw_dirself.processed_dir其实是两个方法,其源码为:

# 加上@property,可以使得方法像属性一样被调用
@property
def raw_dir(self) -> str:
    return osp.join(self.root, 'raw')

@property
def processed_dir(self) -> str:
    return osp.join(self.root, 'processed')

从源码可以看出,self.raw_dirself.processed_dir是给定保存路径root下的原始数据文件夹和处理后的数据文件夹的路径。

2.2 创建演示

本文以SNAP数据集中的一个社交网络Facebook为例,来演示如何创建一个InMemoryDataset数据集FaceBook,该数据集包含4039个节点、88234条边。利用Gephi对该网络进行可视化如下:

facebook

根据3.1节中的说明,下面是自定义FaceBook类的源码:

import os
import pandas as pd
import torch
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset, download_url, extract_gz


class FaceBook(InMemoryDataset):
    url = "https://snap.stanford.edu/data/facebook_combined.txt.gz"

    def __init__(self,
                 root,
                 transform=None,
                 pre_transform=None,
                 pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ["facebook_combined.txt"]

    @property
    def processed_file_names(self):
        return "data.pt"

    def download(self):
        path = download_url(self.url, self.raw_dir)
        extract_gz(path, self.raw_dir)

    def process(self):
        # 加载原始数据文件
        path = os.path.join(self.raw_dir, "facebook_combined.txt")
        edges = pd.read_csv(path, header=None,
                            delimiter=" ").values.reshape(2, -1)
      	# 构建Data对象
        edge_index = torch.from_numpy(edges)
        g = Data(edge_index=edge_index, num_nodes=4039)
        data, slices = self.collate([g])
        torch.save((data, slices), self.processed_paths[0])


if __name__ == "__main__":
    dataset = FaceBook(root="tmp")
    data = dataset[0]
    print(data.num_edges, data.num_nodes)
	# 88234 4039

需要注意的是

  • downloadprocess只在第一次调用时会调用,之后会直接加载处理好的数据集。
  • 以上4个方法并不都是需要的,例如如果你本地已经有了数据集,就不需要重写download()函数来下载原始数据集。

三.大型数据集

对于大型图数据集,需要继承Dataset类,除了InMemoryDataset中需要重写的4个方法外,还需重写如下方法:

  • len(): 返回数据集中实例的数量;
  • get():加载单个图的逻辑。

由于自定义大型数据集与InMemoryDataset类似,具体演示略。

四.结语

参考资料:

自定义数据集是一项重要的事情,尤其是当你本地有些数据需要转换为PyG中标准的图数据集的时候。

  • 5
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

斯曦巍峨

码文不易,有条件的可以支持一下

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值