pytorch geometric 自定义数据集

Pytorch Geometric

一. torch_geometric.data.Data

pytorch Geometric Data使用邻接表去表示图,同时也表示了node特征x, 边属性edge_attr等, 需要注意的是, Data只表示一张图(single graph)

Data(x=None, edge_index=None, edge_attr=None, y=None

x: 表示节点特征,可选,shape: [num_nodes, num_node_features] 有的图只有结构没有节点特征

edge_index: 表示边,也就是邻接表, shape: [2, num_edges]

注意,因为能表示有向图, 对于无向图,一条边要存入两次,也就是位于节点1和节点2的边,需要写成[[1,2][2,1]]而不能只写入[[1],[2]]; node的编号和edge要对应,也就是 max_num_edges = num_nodesnum_nodes 而不是num_nodesnum_nodes /2

edge_attr: 表示边属性(e.g. , 权重,类型),shape: [num_edges, num_edge_features]

y: 是label,官方文档中说 Graph or node targets with arbitrary shape,所以shape可以是[num_nodes, nodes_label_dimension],或者是[graph_label_dimesnion]

二. 构建Dataset

pytorch geometric 构建数据集分两种

  1. 继承InMemoryDataset, 一次性加载所有数据到内存
import torch

from torch_geometric.data import InMemoryDataset

class MyOwnDataset(InMemoryDataset):

    def __init__(self, root, transform=None, pre_transform=None):

        super(MyOwnDataset, self).__init__(root, transform, pre_transform)

        self.data, self.slices = torch.load(self.processed_paths[0])

    @property

    def raw_file_names(self):

        return ['some_file_1', 'some_file_2', ...]

    @property

    def processed_file_names(self):

        return ['data.pt']

    def download(self):

        # Download to `self.raw_dir`.

    def process(self):

        # Read data into huge `Data` list.

        data_list = [...]

        if self.pre_filter is not None:

            data_list = [data for data in data_list if self.pre_filter(data)]

        if self.pre_transform is not None:

            data_list = [self.pre_transform(data) for data in data_list]

        data, slices = self.collate(data_list)

        torch.save((data, slices), self.processed_paths[0])

注意
1.如果需要在initial里面初始化一些参数,如定义mask,需要在super前继承参数
在这里插入图片描述
self.num_train_per_class 要放到super(NodeDatasetInMem, self)这一行前面

2.我们主要需要编辑def processed_file_names(self) 和 def process(self),

processed_file_names只需要申明把处理好的dataset存在哪里(路径加文件名)

process就是写一个函数,处理数据成torch_geometric.data.Data的形式,如果是图分类,还需要把多个图存成一个list

要注意x一般是float tensor, y 是 long tensor, mask 是boolean tensor, edge_index是long tensor

而且当y是graph label时, 不能是0-dimension tensor, 也就是说

y = torch.tensor(0, dtype=torch.long)#错
y = torch.tensor([0], dtype=torch.long)#对

3.其余函数作用

data, slices = self.collate(data_list)

torch.save((data, slices), self.processed_paths[0])

这个是官方代码里面的,作用就是通过self.collate把数据划分成不同slices去保存读取 (大数据块切成小块)

所以即使只有一个graph写成了data, 在调用self.collate时,也要写成list:

data, slices = self.collate([data])
  1. **继承Dataset, ** 分次加载到内存。
    直接继承torch_geometric.data.Dataset,除了和InMemoryDataset相似的函数以外,需要多写两个函数
torch_geometric.data.Dataset.len():

因为Dataset相对于InMemoryDataset,不会一次加载所有函数,而是分批,所有会把数据保存成好几个小数据包(.pt 文件),len() 就是说明有几个数据包,官方的指导:

def len(self):

        return len(self.processed_file_names)

可以完全照搬,只需要改变processed_file_names的返回值,例如
在这里插入图片描述
还有一个get() 函数

torch_geometric.data.Dataset.get():

这个函数需要返回值时一个data,single graph: Implements the logic to load a single graph

def get(self, idx):

        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))

        return data

注意, 这里的load里面的函数名要和processed_file_name()返回的函数名一致, idx就是数据包的遍历下标

几个容易出问题的地方

  1. 继承InMemoryDataset时,在super继承之后,有一个读取数据的命令
    在这里插入图片描述
    由于继承Dataset, 有get函数load数据,所以写继承Dataset时不需要这条命令,否则会报错
  2. 不再调用self.collate() 去划分数据包, 也就没有data_list. 直接把一个个小数据包按照下标储存就好
    在这里插入图片描述
    以后看情况补足raw_file_names()和download()相关,不过本地数据可以不需要填充这两个函数

参考:
[1]: https://www.jianshu.com/p/6b9dccbceae4reference

  • 5
    点赞
  • 18
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值