- TensorDataset 可以用来对 tensor 进行打包,就好像 python 中的 zip 功能。该类通过每一个 tensor 的第一个维度进行索引。因此,该类中的 tensor 第一维度必须相等. 另外:TensorDataset 中的参数必须是 tensor
- DataLoader就是用来包装所使用的数据,每次输出一批数据
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
a = torch.tensor([[1,2,3], [4,5,6], [7,8,9],[0,0,0]])
b = torch.tensor([11,22,33,44])
dataset = TensorDataset(a,b)
print(dataset[0:3]) # 切片输出
for x, y in dataset:
print(x, y)
# DataLoader进行数据封装
print('=' * 80)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
for data in dataloader:
x, y = data
print(x,y)
结果: