本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson2/rmb_classification/
人民币 二分类
实现 1 元人民币和 100 元人民币的图片二分类。前面讲过 PyTorch 的五大模块:数据、模型、损失函数、优化器和迭代训练。
数据模块又可以细分为 4 个部分:
- 数据收集:样本和标签。
- 数据划分:训练集、验证集和测试集
- 数据读取:对应于PyTorch 的 DataLoader。其中 DataLoader 包括 Sampler 和 DataSet。Sampler 的功能是生成索引, DataSet 是根据生成的索引读取样本以及标签。
- 数据预处理:对应于 PyTorch 的 transforms
# DataLoader 与 DataSet
torch.utils.data.DataLoader()
torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=None, pin_memory=False, drop_last=False, timeout=0, worker_init_fn=None, multiprocessing_context=None)
功能:构建可迭代的数据装载器
- dataset: Dataset 类,决定数据从哪里读取以及如何读取
- batchsize: 批大小
- num_works:num_works: 是否多进程读取数据
- sheuffle: 每个 epoch 是否乱序
- drop_last: 当样本数不能被 batchsize 整除时,是否舍弃最后一批数据
Epoch, Iteration, Batchsize
- Epoch: 所有训练样本都已经输入到模型中,称为一个 Epoch
- Iteration: 一批样本输入到模型中,称为一个 Iteration
- Batchsize: 批大小,决定一个 iteration 有多少样本,也决定了一个 Epoch 有多少个 Iteration
假设样本总数有 80,设置 Batchsize 为 8,则共有 80 ÷ 8 = 10 80 \div 8=10 80÷8=10 个 Iteration。这里