CIFAR download via Baidu Netdisk
Link: https://pan.baidu.com/s/132yQGedau02Bw47fz75bYQ
Extraction code: bnvd
Understanding torch.utils.data.DataLoader
Batch training means splitting the data into small batches and training on one batch at a time. DataLoader wraps the dataset you are using and yields one batch per iteration. To understand it concretely, let's look at the following code:
import torch

BATCH_SIZE = 5

# 10 evenly spaced samples: x counts up 1..10, y counts down 10..1
x = torch.linspace(1, 10, 10)
y = torch.linspace(10, 1, 10)

# Wrap the tensors into a dataset, then let DataLoader handle batching
torch_dataset = torch.utils.data.TensorDataset(x, y)
loader = torch.utils.data.DataLoader(
    dataset=torch_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,     # reshuffle the samples every epoch
    num_workers=2,    # load batches in 2 subprocesses
)

def show_batch():
    for epoch in range(3):
        for step, (batch_x, batch_y) in enumerate(loader):
            print("step:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))

if __name__ == '__main__':
    show_batch()
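Because `shuffle=True`, the batches above come out in a different order every run. As a quick sanity check on the batching itself, here is a minimal sketch (my own addition, not from the original post) with `shuffle=False` and `num_workers=0`, which makes the result deterministic: 10 samples with `batch_size=5` yield exactly 2 batches, in dataset order.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

x = torch.linspace(1, 10, 10)   # 1, 2, ..., 10
y = torch.linspace(10, 1, 10)   # 10, 9, ..., 1

# shuffle=False keeps dataset order; num_workers=0 loads in the main process
loader = DataLoader(TensorDataset(x, y), batch_size=5,
                    shuffle=False, num_workers=0)

batches = list(loader)
print(len(batches))              # 10 samples / batch_size 5 -> 2 batches

first_x, first_y = batches[0]
print(first_x.tolist())          # [1.0, 2.0, 3.0, 4.0, 5.0]
print(first_y.tolist())          # [10.0, 9.0, 8.0, 7.0, 6.0]
```

With an uneven split (say 10 samples and `batch_size=4`), the last batch is simply smaller, unless you pass `drop_last=True` to discard it.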
Now let's look at its output: