红色石头的个人网站:
红色石头的个人博客-机器学习、深度学习之路www.redstonewill.com最近发现了一份不错的源代码,作者使用 PyTorch 实现了如今主流的卷积神经网络 CNN 框架,包含了 12 中模型架构。所有代码使用的数据集是 CIFAR。
项目地址:
https://github.com/BIGBALLON/CIFAR-ZOO
CNN 经典论文
该项目实现的是主流的 CNN 模型,涉及的论文包括:
1. CNN 模型(12 篇)
- (lenet) LeNet-5, convolutional neural networks
论文地址:http://yann.lecun.com/exdb/lenet/ - (alexnet) ImageNet Classification with Deep Convolutional Neural Networks
论文地址:https://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks - (vgg) Very Deep Convolutional Networks for Large-Scale Image Recognition
论文地址:https://arxiv.org/abs/1409.1556 - (resnet) Deep Residual Learning for Image Recognition
论文地址:https://arxiv.org/abs/1512.03385 - (preresnet) Identity Mappings in Deep Residual Networks
论文地址:https://arxiv.org/abs/1603.05027 - (resnext) Aggregated Residual Transformations for Deep Neural Networks
论文地址:https://arxiv.org/abs/1611.05431 - (densenet) Densely Connected Convolutional Networks
论文地址:https://arxiv.org/abs/1608.06993 - (senet) Squeeze-and-Excitation Networks
论文地址:https://arxiv.org/abs/1709.01507 - (bam) BAM: Bottleneck Attention Module
论文地址:https://arxiv.org/abs/1807.06514 - (cbam) CBAM: Convolutional Block Attention Module
论文地址:https://arxiv.org/abs/1807.06521 - (genet) Gather-Excite: Exploiting Feature Context in Convolutional Neural Networks
论文地址:https://arxiv.org/abs/1810.12348 - (sknet) SKNet: Selective Kernel Networks
论文地址:https://arxiv.org/abs/1903.06586
2. 正则化(3 篇)
- (shake-shake) Shake-Shake regularization
论文地址:https://arxiv.org/abs/1705.07485 - (cutout) Improved Regularization of Convolutional Neural Networks with Cutout
论文地址:https://arxiv.org/abs/1708.04552 - (mixup) mixup: Beyond Empirical Risk Minimization
论文地址:https://arxiv.org/abs/1710.09412
3. 学习速率调度器(2 篇)
- (cos_lr) SGDR: Stochastic Gradient Descent with Warm Restarts
论文地址:https://arxiv.org/abs/1608.03983 - (htd_lr) Stochastic Gradient Descent with Hyperbolic-Tangent Decay on Classification
论文地址:https://arxiv.org/abs/1806.01593
需求和使用
1. 需求
运行所有代码的开发环境需求为:
- Python >= 3.5
- PyTorch >= 0.4
- TensorFlow/Tensorboard
- 其它依赖项 (pyyaml, easydict, tensorboardX)
作者提供了一键安装、配置开发环境的方法:
pip
2. 模型代码
作者将所有的模型都存放在 model 文件夹下,我们来看一下 PyTorch 实现的 ResNet 网络结构:
# -*-coding:utf-8-*-
其它模型也一并能找到。
3. 使用
简单运行下面的命令就可以运行程序了:
## 1 GPU for lenet
我们使用 yaml 文件 config.yaml 保存参数,查看 ./experimets 中的任何文件以了解更多详细信息。您可以通过 tensorboard 中 tensorboard --logdir path-to-event --port your-port 查看训练曲线。培训日志将通过日志转储,请检查您工作路径中的 log.txt。
模型在 CIFAR 数据集上的结果
1. 12 种 CNN 模型:
2. 正则化
默认的数据扩充方法是 RandomCrop+RandomHorizontalLip+Normalize,而 √ 表示采用哪种附加方法。
PS:Shake_Resnet26_2X64d 通过剪切和混合达到 97.71% 的测试精度!很酷,对吧?
3. 不同的学习速率调度器
最后,再附上项目地址:
https://github.com/BIGBALLON/CIFAR-ZOO
更多AI干货请关注公众号:【AI有道】