TensorDataset
TensorDataset本质上与python zip方法类似,对数据进行打包整合。
官方文档说明:
**Dataset wrapping tensors.
Each sample will be retrieved by indexing tensors along the first dimension.*
Parameters:
tensors (Tensor) – tensors that have the same size of the first dimension.
该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等。
import torch
from torch.utils.data import TensorDataset
# a的形状为(4*3)
a = torch.tensor([[1,1,1],[2,2,2],[3,3,3],[4,4,4]])
# b的第一维与a相同
b = torch.tensor