在使用PyTorch时,经常需要自己写DataLoader,写DataLoader前需要先写Dataset,通常情况可能需要给定一个file_path动态加载,这里先定义一个X和Y。
Dataset类需要实现__getitem__和__len__两个函数
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
import torch.utils.data as Data
import torch
class MyDataSet(Dataset):
def __init__(self, file_path):
super().__init__()
self.X = torch.randn(20)
self.Y = torch.linspace(1, 10, 20)
self._len = len(self.X)
pass
def __getitem__(self, idx):
return (self.X[idx], self.Y[idx])
def __len__(self):
return self._len
我们测试下刚刚的Dataset并建立DataLoader,DataLoader的好处是可以一次生成batch个数据,方便训练和计算。
data_set = MyDataSet('')
data_loder = DataLoader(dataset=data_set, batch_size=6, shuffle=True)
for batch in data_loader:
print(batch)
会得到类似这样的输出: ---------------------- [tensor([6., 4., 3.]), tensor([5., 7., 8.])] [tensor([ 2., 5., 10.]), tensor([9., 6., 1.])] [tensor([7., 1., 9.]), tensor([ 4., 10., 2.])] [tensor([8.]), tensor([3.])]
-------------------------------------
[tensor([7., 3., 9.]), tensor([4., 8., 2.])] [tensor([4., 6., 1.]), tensor([ 7., 5., 10.])] [tensor([ 5., 8., 10.]), tensor([6., 3., 1.])] [tensor([2.]), tensor([9.])]