源码如下:
class TensorDataset(Dataset):
    """Dataset wrapping tensors.

    Each sample is retrieved by indexing every wrapped tensor along its
    first dimension.

    Args:
        *tensors (Tensor): tensors that all have the same size in the
            first dimension.

    Raises:
        AssertionError: if any tensor's first-dimension size differs from
            that of the first tensor.
    """

    def __init__(self, *tensors):
        # Raise explicitly instead of using a bare ``assert`` statement:
        # assert statements are removed under ``python -O``, which would let
        # mismatched tensors through unchecked. AssertionError is kept as
        # the exception type so existing callers are unaffected.
        if not all(tensors[0].size(0) == tensor.size(0) for tensor in tensors):
            raise AssertionError("Size mismatch between tensors")
        self.tensors = tensors

    def __getitem__(self, index):
        """Return a tuple holding ``tensor[index]`` for each wrapped tensor."""
        return tuple(tensor[index] for tensor in self.tensors)

    def __len__(self):
        """Return the size of the first dimension of the first tensor."""
        return self.tensors[0].size(0)
可以看到,它把之前的 data_tensor 和 target_tensor 两个参数去掉了,输入变成了可变参数 *tensors(在函数内部以元组形式接收),只需将 data 和 target 直接按位置传入即可。
# Example: batching a TensorDataset with a DataLoader.
import torch
import torch.utils.data as Data

BATCH_SIZE = 5

# Ten evenly spaced samples: x goes from 1 to 10, y goes from 10 to 1.
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# Pass the tensors positionally; each sample is the pair (x[i], y[i]).
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)

# Iterating a DataLoader with num_workers > 0 starts worker processes. On
# platforms that start processes with "spawn" (e.g. Windows, macOS) each
# worker re-imports this module, so the loop must sit behind the main guard
# to avoid being re-executed inside the workers.
if __name__ == "__main__":
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())