前言
在新手搭建神经网络时,常弄不清epoch、batch_size、iteration和batch_idx(iteration )的区别。
这里以torchvision自带的CIFAR10数据集来举例,通过代码操作来直观地对这几个概念进行理解。
声明,这里batch_idx==iteration。
数据准备
首先加载数据集:
import torch
import torch.nn as nn
import torchvision
train_dataset = torchvision.datasets.CIFAR10(root="data/",train=True,download=False)
test_dataset = torchvision.datasets.CIFAR10(root=