pytorch load huge dataset(大数据加载)

问题

最近用pytorch做实验时,遇到加载大量数据的问题。实验数据大小在400Gb,而本身机器的memory只有256Gb,显然无法将数据一次全部load到memory。

解决方法

首先自定义一个MyDataset继承torch.utils.data.Dataset,然后将MyDataset的对象feed in torch.utils.data.DataLoader()即可。MyDataset在__init__中声明一个文件对象,然后在__getitem__中缓慢读取数据,这样就不会一次把所有数据加载到内存中了。训练数据存放在train.txt中,每一行是一条数据记录。

import torch.utils.data as Data
from tqdm import tqdm
class MyDataset(Data.Dataset):
	def __init__(self,filepath):
		number = 0
		with open(filepath,"r") as f:
			# 获得训练数据的总行数
			for _ in tqdm(f,desc="load training dataset"):
				number+=1
		self.number = number
		self.fopen = open(filepath,'r')
	def __len__(self):
		return self.number
	def __getitem__(self,index):
		line = self.fopen.__next__()
		# 自定义transform()对训练数据进行预处理
		data = transform(line)
		return data

train_dataset = MyDataset(filepath = "train.txt&
### PyTorch Geometric 使用教程和代码实例 #### 创建图数据集 为了创建自定义的数据集,在 `torch_geometric.data` 中继承 `InMemoryDataset` 类并重写一些必要的函数可以轻松完成此操作。除了基本的方法外,还需要实现特定于应用的逻辑来处理数据读取和预处理。 对于任何新的数据集类来说,至少要覆盖如下几个成员函数[^4]: - `_download()`:如果适用的话,下载原始文件到指定路径下。 - `_process()` :从原始输入转换成适合模型训练的形式,并保存至磁盘供后续快速访问。 - `len(self)`:返回整个集合里样本的数量。 - `get(self, idx)`:获取索引位置处的具体条目信息。 下面给出一段简单的例子展示如何构建自己的图数据库: ```python from torch_geometric.datasets import InMemoryDataset import os.path as osp import torch class MyOwnDataset(InMemoryDataset): def __init__(self, root, transform=None, pre_transform=None): super(MyOwnDataset, self).__init__(root, transform, pre_transform) self.data, self.slices = torch.load(self.processed_paths[0]) @property def raw_file_names(self): return ['some_file_1', 'some_file_2'] @property def processed_file_names(self): return ['data.pt'] def download(self): # Download to `self.raw_dir`. pass def process(self): # Read data into huge `Data` list. ... if self.pre_filter is not None: data_list = [data for data in data_list if self.pre_filter(data)] if self.pre_transform is not None: data_list = [self.pre_transform(data) for data in data_list] data, slices = self.collate(data_list) torch.save((data, slices), self.processed_paths[0]) ``` 这段代码展示了怎样通过子类化 `InMemoryDataset` 来建立一个新的图数据集对象。这里省略了一些具体细节比如实际的数据加载过程 (`process`) 和过滤/变换规则(`pre_filter`, `pre_transform`) 的设定;这些部分应该依据具体的项目需求而定。 #### 构建简单 GNN 模型 一旦有了准备好的数据之后就可以着手设计网络架构了。PyG 提供了许多内置模块可以直接拿来组合使用,同时也允许开发者自行扩展新组件以满足特殊的应用场合。以下是一份简易版 GCN (Graph Convolutional Network) 的搭建方式: ```python import torch.nn.functional as F from torch_geometric.nn import GCNConv class Net(torch.nn.Module): def __init__(self): super().__init__() self.conv1 = GCNConv(dataset.num_node_features, 16) self.conv2 = GCNConv(16, dataset.num_classes) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) x = F.dropout(x, training=self.training) x = self.conv2(x, edge_index) return F.log_softmax(x, dim=1) device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = Net().to(device) data = dataset[0].to(device) optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4) for epoch in range(200): model.train() optimizer.zero_grad() out = model(data) loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() model.eval() _, pred = model(data).max(dim=1) correct = float(pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()) acc = correct / data.test_mask.sum().item() print(f'Accuracy: {acc:.4f}') ``` 上述脚本首先导入所需的依赖项,接着定义了一个两层卷积结构作为分类器主体。每一层都接收节点特征矩阵以及边连接关系列表作为参数执行前向传播运算。最后利用交叉熵损失函数配合 Adam 优化算法迭代更新权重直至收敛为止。测试阶段则计算预测标签与真实类别之间的匹配度从而得出最终性能指标——准确率。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值