天主要分享两份 Github 项目,都是采用 PyTorch 来实现深度学习网络模型,主要是一些常用的模型,包括如 ResNet、DenseNet、ResNext、SENet等,并且也给出相应的实验结果,包含完整的数据处理和载入、模型建立、训练流程搭建,以及测试代码的实现。
接下来就开始介绍这两个项目。
- PyTorch Image Classification
这份代码目前有 200+ 星,主要实现以下的网络,在 MNIST、CIFAR10、FashionMNIST等数据集上进行实验。
使用方法如下:
然后就是给出作者自己训练的实验结果,然后和原论文的实验结果的对比,包括在训练设置上的区别,然后训练的迭代次数和训练时间也都分别给出。
之后作者还研究了残差单元、学习率策略以及数据增强对分类性能的影响,比如
类似金字塔网络的残差单元设计(PyramidNet-like residual units)
cosine 函数的学习率递减策略(Cosine annealing of learning rate)
Cutout
随机消除(Random Erasing)
Mixup
降采样后的预激活捷径(Preactivation of shortcuts after downsampling)
实验结果表明:
类似金字塔网络的残差单元设计有帮助,但不适宜搭配 Preactivation of shortcuts after downsampling
基于 cosine 的学习率递减策略提升幅度较小
Cutout、随机消除以及 Mixup 效果都很好,其中 Mixup 需要的训练次数更多
除了这个实验,后面作者还继续做了好几个实验,包括对 batch 大小、初始学习率大小