Datawhale 图神经网络task7

引用自Datawhale

https://github.com/datawhalechina/team-learning-nlp/tree/master/GNN

前面文章:

  1. 图神经网络打卡task1
  2. Datawhale 图神经网络task2
  3. Datawhale 图神经网络task3
  4. Datawhale 图神经网络task4
  5. Datawhale 图神经网络task5
  6. Datawhale 图神经网络task6

超大规模数据集类的创建

在前面的学习中我们只接触了数据可全部储存于内存的数据集,这些数据集对应的数据集类在创建对象时就将所有数据都加载到内存。然而在一些应用场景中,数据集规模超级大,我们很难有足够大的内存完全存下所有数据。因此需要一个按需加载样本到内存的数据集类。在此上半节内容中,我们将学习为一个包含上千万个图样本的数据集构建一个数据集类。

Dataset基类简介

在PyG中,我们通过继承torch_geometric.data.Dataset基类来自定义一个按需加载样本到内存的数据集类。此基类与Torchvision的Dataset类的概念密切相关,这与第6节中介绍的torch_geometric.data.InMemoryDataset基类是一样的。

继承torch_geometric.data.InMemoryDataset基类要实现的方法,继承此基类同样要实现,此外还需要实现以下方法

  • len():返回数据集中的样本的数量。
  • get():实现加载单个图的操作。注意:在内部,__getitem__()返回通过调用get()来获取Data对象,并根据transform参数对它们进行选择性转换。

下面让我们通过一个简化的例子看继承torch_geometric.data.Dataset基类的规范

import os.path as osp

import torch
from torch_geometric.data import Dataset, download_url

class MyOwnDataset(Dataset):
    def __init__(self, root, transform=None, pre_transform=None):
        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

    @property
    def raw_file_names(self):
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data_1.pt', 'data_2.pt', ...]

    def download(self):
        # Download to `self.raw_dir`.
        path = download_url(url, self.raw_dir)
        ...

    def process(self):
        i = 0
        for raw_path in self.raw_paths:
            # Read data from `raw_path`.
            data = Data(...)

            if self.pre_filter is not None and not self.pre_filter(data):
                continue

            if self.pre_transform is not None:
                data = self.pre_transform(data)

            torch.save(data, osp.join(self.processed_dir, 'data_{}.pt'.format(i)))
            i += 1

    def len(self):
        return len(self.processed_file_names)

    def get(self, idx):
        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))
        return data

其中,每个Data对象在process()方法中单独被保存,并在get()中通过指定索引进行加载。

跳过download/process

对于无需下载数据集原文件的情况,我们不重写(override)download方法即可跳过下载。对于无需对数据集做预处理的情况,我们不重写process方法即可跳过预处理。

无需定义Dataset类

通过下面的方式,我们可以不用定义一个Dataset类,而直接生成一个Dataloader对象,直接用于训练:

from torch_geometric.data import Data, DataLoader

data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

我们也可以通过下面的方式将一个列表的Data对象组成一个batch

from torch_geometric.data import Data, Batch

data_list = [Data(...), ..., Data(...)]
loader = Batch.from_data_list(data_list, batch_size=32)

图样本封装成批(BATCHING)与DataLoader

内容来源:ADVANCED MINI-BATCHING

合并小图组成大图

图可以有任意数量的节点和边,它不是规整的数据结构,因此对图数据封装成批的操作与对图像与序列等数据封装成批的操作不同。PyTorch Geometric中采用的将多个图封装成批的方式是,将小图作为连通组件(connected component)的形式合并,构建一个大图。于是小图的邻接矩阵存储在大图邻接矩阵的对角线上。大图的邻接矩阵、属性矩阵、预测目标矩阵分别为:
KaTeX parse error: No such environment: split at position 8: \begin{̲s̲p̲l̲i̲t̲}̲\mathbf{A} = \b…
此方法有以下关键的优势

  • 依靠消息传递方案的GNN运算不需要被修改,因为消息仍然不能在属于不同图的两个节点之间交换。

  • 没有额外的计算或内存的开销。例如,这个批处理程序的工作完全不需要对节点或边缘特征进行任何填充。请注意,邻接矩阵没有额外的内存开销,因为它们是以稀疏的方式保存的,只保留非零项,即边。

通过torch_geometric.data.DataLoader类,多个小图被封装成一个大图。

  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值