import torch
import torch.utils.data as Data
torch.manual_seed(1)  # make the shuffling order reproducible across runs

# With 10 samples, a batch size of 5 yields two full batches per epoch. A size
# that does not divide 10 (e.g. 8) leaves a smaller final batch, because
# DataLoader keeps incomplete batches by default (drop_last=False).
BATCH_SIZE = 5

x = torch.linspace(1, 10, 10)  # x data: 1.0, 2.0, ..., 10.0 (torch tensor)
y = torch.linspace(10, 1, 10)  # y data: 10.0, 9.0, ..., 1.0 (torch tensor)

# Wrap the tensors in a Dataset that DataLoader can consume; item i is (x[i], y[i]).
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
    dataset=torch_dataset,  # torch TensorDataset format
    batch_size=BATCH_SIZE,  # mini-batch size
    shuffle=True,  # reshuffle the samples every epoch (usually preferable)
    num_workers=2,  # number of worker subprocesses used to load batches
)


def show_batch(epochs=3, data_loader=None):
    """Print every mini-batch produced by a loader over several epochs.

    Args:
        epochs: number of full passes over the dataset (default 3).
        data_loader: iterable yielding ``(batch_x, batch_y)`` tensor pairs;
            defaults to the module-level ``loader``.
    """
    if data_loader is None:
        data_loader = loader
    for epoch in range(epochs):  # go over the entire dataset `epochs` times
        # enumerate yields the batches in order, paired with a 0-based step index
        for step, (batch_x, batch_y) in enumerate(data_loader):
            # train your data here...
            print('Epoch: ', epoch, '| Step: ', step, '| batch x: ',
                  batch_x.numpy(), '| batch y: ', batch_y.numpy())


if __name__ == '__main__':
    # The guard is required because DataLoader worker processes may re-import
    # this module (e.g. with the "spawn" start method on Windows/macOS).
    show_batch()