torch中的data迭代方法
介绍
- 看代码的过程中不难发现,不同作者模型训练时的数据输入方法差别非常大。
torch
提供了统一的接口,通过迭代器实现数据和标签的读取,使用方便也利于阅读。
实现方法
-
导入
from torch.utils.data import Dataset, DataLoader
-
Dataset
torch
内置抽象类,无法实例化,通过继承并重写魔术方法实现
class MyDataset(Dataset): def __init__(self, filepath): xy = np.load(filepath) self.len = xy.shape[0] self.x_data = torch.from_numpy(xy[:, :-1]) self.y_data = torch.from_numpy(xy[:, [-1]]) def __getitem__(self, item): return self.x_data[item], self.y_data[item] def __len__(self): return self.len dataset = MyDataset('MyData.npy')
- 示例中,以读取numpy文件为例,通过重写
__getitem__
,__len__
方法,实现数据的随机读取
-
Dataloader
- 调用
dataset
实例,通过设定的参数可生成DataLoader
train_loader = DataLoader(dataset=dataset, batch_size=32, shuffle=True, num_workers=2)
- 调用
-
训练中调用数据
for i, data in enumerate(train_loader, 0): # x, y = data