创建自己的数据集 InMemoryDataset

创建自己的数据集

https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html
按照torchvision约定,每个数据集都会传递一个根文件夹,该根文件夹指示应将数据集存储在何处。我们将根文件夹分为两个文件夹:raw_dir,将数据集下载到,和processed_dir,将处理后的数据集保存在。

torch_geometric.data.InMemoryDataset.process()
在这里,我们需要阅读并创建Data对象列表并将其保存到中processed_dir。由于保存庞大的python列表相当慢,因此我们在保存之前将列表整理为一个庞大的Data对象torch_geometric.data.InMemoryDataset.collate()。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回了slices字典以从该对象中重构单个示例。最后**,我们需要将构造函数中的这两个对象加载到属性self.data和中self.slices。**

问题:真的需要使用这些数据集接口吗?
不!就像在常规PyTorch中一样,您无需使用数据集,例如,当您想动态创建合成数据而无需将其显式保存到磁盘时。在这种情况下,只需传递一个torch_geometric.data.Data包含对象的常规python列表,然后将它们传递给torch_geometric.data.DataLoader:

from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)

InMemoryDataset

CLASSInMemoryDataset(root=None, transform=None, pre_transform=None, pre_filter=None)
用于创建完全适合CPU内存的图形数据集的数据集基类。请参阅这里的相关教程。
root(字符串,可选)-应该保存数据集的根目录。
transform (callable, optional) -返回转换后的版本。数据对象将在每次访问之前进行转换。
pre_transform (callable, optional) 返回转换后的版本。数据对象在保存到磁盘之前将进行转换。
pre_filter (callable, optional) -返回一个布尔值,该值指示该数据对象是否应包括在最终数据集中。

collate(data_list)
将数据对象的python列表整理为torch_geometric.data.InMemoryDataset.的内部存储格式。

需要输入的数据是:
① _A.txt:存储邻接信息,具体来说,就是把很多个图的节点按顺序编号,然后存入节点之间的邻接信息。
② _graph_indicator.txt:第i行对应的值x 表示 第i个节点属于第x个图
③ _graph_lables.txt:第i行对应的值x 表示 第i个图属于第x个类
④ _node_attributes.txt:第i行表示 第i个节点的特征向量

举例:DD数据集

D&D 在蛋白质数据库的非冗余子集中抽取了了1178个高分辨率蛋白质,使用简单的特征,如二次结构含量、氨基酸倾向、表面性质和配体;其中节点是氨基酸,如果两个节点之间的距离少于6埃(Angstroms),则用一条边连接。(DD数据集中节点是没有标签的,节点只有特征)

import numpy as np
data = np.genfromtxt('waveform.txt',delimiter=',',skip_header=18)
  1. np.genfromtxt:数据文件的一种非常常见的文件格式是逗号分隔值(CSV)或相关格式,如TSV(制表符分隔值)。要将数据从这些文件读入Numpy数组,我们可以使用Numpy.genfromtxt函数。
    https://blog.csdn.net/weixin_41811657/article/details/84614818
  2. 区分横纵坐标 delimiter=‘,’ 以逗号为间隔
    delimiter: the str used to separate data. 横纵坐标以 ‘,’ 分割,因此给 delimiter 传入 ‘,’。
  3. skip_header: the number of lines to skip at the beginning of the file. 有用数据是从19行开始的,因此给 skip_header 传入 18。

① _A.txt:存储邻接信息,具体来说,就是把很多个图的节点按顺序编号,然后存入节点之间的邻接信息。
该数据集有1686092条边

print("Loading DD_A.txt") #从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置)  包含所有图的节点
adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),dtype=np.int64, delimiter=',') - 1
print('adjacency_list',adjacency_list[0:100])
print(len(adjacency_list))

在这里插入图片描述
在这里插入图片描述
② _graph_indicator.txt:第i行对应的值x 表示 第i个节点属于第x个图

print("Loading DD_node_labels.txt")#读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"), dtype=np.int64) - 1
print('node_labels',node_labels[0:100])
print(len(node_labels))

在这里插入图片描述

③ _graph_lables.txt:第i行对应的值x 表示 第i个图属于第x个类

print("Loading DD_graph_indicator.txt")#每个节点属于哪个图
graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"), dtype=np.int64) - 1
print('graph_indicator',graph_indicator[0:100])
print(len(graph_indicator))

在这里插入图片描述

④ _node_attributes.txt:第i行表示 第i个节点的特征向量

print("Loading DD_graph_labels.txt")#每个图的标签 (2分类 01)
graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"), dtype=np.int64) - 1
print('graph_labels',graph_labels[0:100])
print(len(graph_labels))

在这里插入图片描述

def read_data(self):
        #解压后的路径
        data_dir = os.path.join(self.data_root, "DD")
        print("Loading DD_A.txt")
        #从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置)  包含所有图的节点
        adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),
                                       dtype=np.int64, delimiter=',') - 1
        print("Loading DD_node_labels.txt")
        #读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
        node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"), 
                                    dtype=np.int64) - 1
        print("Loading DD_graph_indicator.txt")
        #每个节点属于哪个图
        graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"), 
                                        dtype=np.int64) - 1
        print("Loading DD_graph_labels.txt")
        #每个图的标签 (2分类 0,1)
        graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"), 
                                     dtype=np.int64) - 1
        num_nodes = len(node_labels) #节点数 (包含所有图的节点)
        #通过邻接表生成邻接矩阵  (包含所有的图)稀疏存储节省内存(coo格式 只存储非0值的行索引、列索引和非0值)
        #coo格式无法进行稀疏矩阵运算
        sparse_adjacency = sp.coo_matrix((np.ones(len(adjacency_list)), 
                                          (adjacency_list[:, 0], adjacency_list[:, 1])),
                                         shape=(num_nodes, num_nodes), dtype=np.float32)
        print("Number of nodes: ", num_nodes)
        return sparse_adjacency, node_labels, graph_indicator, graph_labels

通过read_data函数,也就是我们想要得到的是(adjacency_list)sparse_adjacency, node_labels, graph_indicator, graph_labels这四个数组
Loading DD_A.txt
(1686092, 2)
Loading DD_node_labels.txt 读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸
(334925,)
Loading DD_graph_indicator.txt 每个节点属于哪个图
(334925,)
Loading DD_graph_labels.txt 每个图的标签 (2分类 0,1)
(1178,)

Index文件

在这里插入图片描述
在这里插入图片描述
在80:10:10中将总数的图表分成3个(train, val和test)

# reading idx from the files
root_idx_dir = '/home/xx/data/benchmarkingdata/TUs/'
dataset_name = 'DD'
all_idx = {}
for section in ['train', 'val', 'test']:
    with open(root_idx_dir + dataset_name + '_'+ section + '.index', 'r') as f:
        reader = csv.reader(f)
        all_idx[section] = [list(map(int, idx)) for idx in reader]
a = all_idx['test']
print(len(a))     #10
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
在使用PyTorch Geometric (pyg)创建自制数据集时,你可以继承`InMemoryDataset`类,并重写一些必要的方法来实现数据集的下载、处理和加载。首先,你需要定义一个类,例如`MyDataset`,并将其继承自`InMemoryDataset`。在构造函数中,你可以调用父类的构造函数,并在其中加载数据集。你可以使用`torch.load`函数加载数据,并将其赋值给`self.data`和`self.slices`。同时,你可以重写`raw_file_names`和`processed_file_names`方法来指定原始文件和处理后的文件的名称。在`download`方法中,你可以实现数据集的下载逻辑。在`process`方法中,你可以对数据集进行处理,并将处理后的数据保存到指定的路径。最后,在`collate`函数中,你可以将一个python列表形式的数据转换为`InMemoryDataset`内部存储数据的格式。你可以使用`data_list\[0\].__class__()`来创建一个空的数据对象,然后将列表中的每个数据对象转换为该类型。\[2\]\[3\] #### 引用[.reference_title] - *1* [在PyG上构建自己的数据集](https://blog.csdn.net/qq_32113189/article/details/126663738)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] - *2* *3* [【PyG入门学习】四:构建自己的数据集](https://blog.csdn.net/TwT520Ly/article/details/105633847)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v91^koosearch_v1,239^v3^insert_chatgpt"}} ] [.reference_item] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值