使用torch.utils.data.TensorDataset
来制作自己的数据集,
参考PyTorch: How to use DataLoaders for custom Datasets
.
import torch.utils.data as data_utils
train = data_utils.TensorDataset(features, targets)
train_loader = data_utils.DataLoader(train, batch_size=50, shuffle=True)
features为2D-Tensor,即需要将 3D-Tensor的RGB 图片或者是2D-Tensor的灰度图片拉伸成1D-Tensor,在使用的时候再还原。一缩一放之后的 tensor是否一致呢?
import torch
a = torch.rand(10,3,4,5)
b = a.view(10,-1)
c = b.view(10,3,4,5)
print(torch.equal(c, a))
# True
输出结果为 True
,前后缩放的 Tensor
一致。