创建自己的数据集
https://pytorch-geometric.readthedocs.io/en/latest/notes/create_dataset.html
按照torchvision约定,每个数据集都会传递一个根文件夹,该根文件夹指示应将数据集存储在何处。我们将根文件夹分为两个文件夹:raw_dir,将数据集下载到,和processed_dir,将处理后的数据集保存在。
torch_geometric.data.InMemoryDataset.process()
在这里,我们需要阅读并创建Data对象列表并将其保存到中processed_dir。由于保存庞大的python列表相当慢,因此我们在保存之前将列表整理为一个庞大的Data对象torch_geometric.data.InMemoryDataset.collate()。整理后的数据对象将所有示例连接到一个大数据对象中,此外,还返回了slices字典以从该对象中重构单个示例。最后**,我们需要将构造函数中的这两个对象加载到属性self.data和中self.slices。**
问题:真的需要使用这些数据集接口吗?
不!就像在常规PyTorch中一样,您无需使用数据集,例如,当您想动态创建合成数据而无需将其显式保存到磁盘时。在这种情况下,只需传递一个torch_geometric.data.Data包含对象的常规python列表,然后将它们传递给torch_geometric.data.DataLoader:
from torch_geometric.data import Data, DataLoader
data_list = [Data(...), ..., Data(...)]
loader = DataLoader(data_list, batch_size=32)
InMemoryDataset
CLASSInMemoryDataset(root=None, transform=None, pre_transform=None, pre_filter=None)
用于创建完全适合CPU内存的图形数据集的数据集基类。请参阅这里的相关教程。
root(字符串,可选)-应该保存数据集的根目录。
transform (callable, optional) -返回转换后的版本。数据对象将在每次访问之前进行转换。
pre_transform (callable, optional) 返回转换后的版本。数据对象在保存到磁盘之前将进行转换。
pre_filter (callable, optional) -返回一个布尔值,该值指示该数据对象是否应包括在最终数据集中。
collate(data_list)
将数据对象的python列表整理为torch_geometric.data.InMemoryDataset.的内部存储格式。
需要输入的数据是:
① _A.txt:存储邻接信息,具体来说,就是把很多个图的节点按顺序编号,然后存入节点之间的邻接信息。
② _graph_indicator.txt:第i行对应的值x 表示 第i个节点属于第x个图
③ _graph_lables.txt:第i行对应的值x 表示 第i个图属于第x个类
④ _node_attributes.txt:第i行表示 第i个节点的特征向量
举例:DD数据集
D&D 在蛋白质数据库的非冗余子集中抽取了了1178个高分辨率蛋白质,使用简单的特征,如二次结构含量、氨基酸倾向、表面性质和配体;其中节点是氨基酸,如果两个节点之间的距离少于6埃(Angstroms),则用一条边连接。(DD数据集中节点是没有标签的,节点只有特征)
import numpy as np
data = np.genfromtxt('waveform.txt',delimiter=',',skip_header=18)
- np.genfromtxt:数据文件的一种非常常见的文件格式是逗号分隔值(CSV)或相关格式,如TSV(制表符分隔值)。要将数据从这些文件读入Numpy数组,我们可以使用Numpy.genfromtxt函数。
https://blog.csdn.net/weixin_41811657/article/details/84614818 - 区分横纵坐标 delimiter=‘,’ 以逗号为间隔
delimiter: the str used to separate data. 横纵坐标以 ‘,’ 分割,因此给 delimiter 传入 ‘,’。 - skip_header: the number of lines to skip at the beginning of the file. 有用数据是从19行开始的,因此给 skip_header 传入 18。
① _A.txt:存储邻接信息,具体来说,就是把很多个图的节点按顺序编号,然后存入节点之间的邻接信息。
该数据集有1686092条边
print("Loading DD_A.txt") #从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置) 包含所有图的节点
adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),dtype=np.int64, delimiter=',') - 1
print('adjacency_list',adjacency_list[0:100])
print(len(adjacency_list))
② _graph_indicator.txt:第i行对应的值x 表示 第i个节点属于第x个图
print("Loading DD_node_labels.txt")#读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"), dtype=np.int64) - 1
print('node_labels',node_labels[0:100])
print(len(node_labels))
③ _graph_lables.txt:第i行对应的值x 表示 第i个图属于第x个类
print("Loading DD_graph_indicator.txt")#每个节点属于哪个图
graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"), dtype=np.int64) - 1
print('graph_indicator',graph_indicator[0:100])
print(len(graph_indicator))
④ _node_attributes.txt:第i行表示 第i个节点的特征向量
print("Loading DD_graph_labels.txt")#每个图的标签 (2分类 0,1)
graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"), dtype=np.int64) - 1
print('graph_labels',graph_labels[0:100])
print(len(graph_labels))
def read_data(self):
#解压后的路径
data_dir = os.path.join(self.data_root, "DD")
print("Loading DD_A.txt")
#从txt文件中读取邻接表(每一行可以看作一个坐标,即邻接矩阵中非0值的位置) 包含所有图的节点
adjacency_list = np.genfromtxt(os.path.join(data_dir, "DD_A.txt"),
dtype=np.int64, delimiter=',') - 1
print("Loading DD_node_labels.txt")
#读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸 氨基酸有20多种,所以每个节点会有一个类型标签 表示是哪一种氨基酸
node_labels = np.genfromtxt(os.path.join(data_dir, "DD_node_labels.txt"),
dtype=np.int64) - 1
print("Loading DD_graph_indicator.txt")
#每个节点属于哪个图
graph_indicator = np.genfromtxt(os.path.join(data_dir, "DD_graph_indicator.txt"),
dtype=np.int64) - 1
print("Loading DD_graph_labels.txt")
#每个图的标签 (2分类 0,1)
graph_labels = np.genfromtxt(os.path.join(data_dir, "DD_graph_labels.txt"),
dtype=np.int64) - 1
num_nodes = len(node_labels) #节点数 (包含所有图的节点)
#通过邻接表生成邻接矩阵 (包含所有的图)稀疏存储节省内存(coo格式 只存储非0值的行索引、列索引和非0值)
#coo格式无法进行稀疏矩阵运算
sparse_adjacency = sp.coo_matrix((np.ones(len(adjacency_list)),
(adjacency_list[:, 0], adjacency_list[:, 1])),
shape=(num_nodes, num_nodes), dtype=np.float32)
print("Number of nodes: ", num_nodes)
return sparse_adjacency, node_labels, graph_indicator, graph_labels
通过read_data函数,也就是我们想要得到的是(adjacency_list)sparse_adjacency, node_labels, graph_indicator, graph_labels这四个数组
Loading DD_A.txt
(1686092, 2)
Loading DD_node_labels.txt 读取节点的特征标签(包含所有图) 每个节点代表一种氨基酸
(334925,)
Loading DD_graph_indicator.txt 每个节点属于哪个图
(334925,)
Loading DD_graph_labels.txt 每个图的标签 (2分类 0,1)
(1178,)
Index文件
在80:10:10中将总数的图表分成3个(train, val和test)
# reading idx from the files
root_idx_dir = '/home/xx/data/benchmarkingdata/TUs/'
dataset_name = 'DD'
all_idx = {}
for section in ['train', 'val', 'test']:
with open(root_idx_dir + dataset_name + '_'+ section + '.index', 'r') as f:
reader = csv.reader(f)
all_idx[section] = [list(map(int, idx)) for idx in reader]
a = all_idx['test']
print(len(a)) #10