(Pytorch)零基础ResNet训练CIFAR-10数据集实时推理

在实现ResNet-50训练CIFAR-10数据的过程中,发现使用默认参数进行训练时只能达到65-80的准确率,并不满足我的预期要求,查了一些资料,发现可能是图片尺寸不满足卷积下采样倍数的原因。

代码:代码如果对您有帮助,请在GitHub上给一颗小星星!!!GitHub - zzbbzz626/ResNet-cifar10-real-time-inferenceContribute to zzbbzz626/ResNet-cifar10-real-time-inference development by creating an account on GitHub.https://github.com/zzbbzz626/ResNet-cifar10-real-time-inference

下面给出我的思路:

1 模型简介

ResNet是何恺明等人在2015年提出的,获得了ImageNet分类任务比赛中的第一名,阐述ResNet的论文《Deep Residual Learning for Image Recognition》获评为CVPR2016最佳论文。

2 数据集简介

CIFAR-10是更接近普适物体的彩色图像数据集,一共包含10个类别的RGB彩色图像。每个图片尺寸为32x32。每个类别有6000个图像,数据集中共有50000张训练图片和10000张测试图片。

模型训练实验

图3 改进说明图

列举原始模型参数设置如下:

表1 原始模型网络参数设置

Layer_name

Kernel_size

stride

padding

Conv1

7

2

3

(1)对原始模型进行训练

分别采用不同的batch_szie和learn_rate尝试优化,得出如表2的训练准确率,并给出通过tensorboard实时可视化训练过程图:

表2 原始模型训练准确率

epoch

batch_size

learn_rate

loss

train_acc

test_acc

10

64

0.1

0.9521

68%

65%

50

32

0.001(1~20)

0.0002(20~50)

0.244

87.41%

79.21%

200

32

0.001(1~20)

0.0002(20~50)

0.4413

93.58%

79.73%

图4 对应表2顺序的训练可视化图

从表2和图4中可以明确看出通过改变batch_size的大小和动态学习率的设计可以将训练准确率调高,但准确率并不高,并且从50轮训练增加到200轮训练后准确率并没有显著的提高。

(2)问题分析

从图4中可以看出导致准确率没有显著提高的原因并不是过拟合训练。因此为解决上述问题,从网络模型的输入开始进行分析。通过第2部分对CIFAR-10数据集的介绍可以知道,CIFAR数据集的图片大小为32x32。通过对图片经过Conv1层的过程进行模拟得到模拟图5。

图5  CIFAR图片Conv1层模拟图

从图5中可以看出采用原始模型参数的Conv1层时并不满足卷积过程的公式,因此初步推断影响准确率的原因可能是原始模型Conv1层的参数设计不是和CIFAR-10数据集图片的输入大小非常合适。

(3)模型修改

根据上述分析,针对Conv1层的模型参数进行修改,修改后的参数设置如表3所示:

表3 网络参数设置

Layer_name

Kernel_size

stride

padding

Conv1

3

1

1

图6  CIFAR图片修改后Conv1层模拟图

通过对图片经过修改参数后Conv1层的过程进行模拟得到模拟图6。可以看出Conv1层修改后的参数设计和CIFAR-10数据集图片的输入大小非常合适。

(4)对修改后的模型进行训练

分别采用不同的batch_szie和learn_rate尝试优化,得出如表4的训练准确率,并给出通过tensorboard实时可视化训练过程图:

表4 训练准确率

epoch

batch_size

learn_rate

loss

train_acc

test_acc

50

128

0.1

0.3577

86.94%

84.7%

200

128

0.1

6.0595x10-4

99.99%

95.44%

图7 对应表4顺序的训练可视化图

从表4和图7中可以明确看出经过修改参数后的模型进行训练可以达到非常高的精确度。

4 模型实时推理应用实验

为体现训练模型的在实际应用中的泛化能力,利用已经训练好的模型进行实时检测。由于CIFAR-10数据集类别的特殊性,并不能采用实体目标进行检测,因此本实验基于OpenCV采用外置摄像头对网页中随机搜索到的照片中的物体进行实时推理预测。

图8 猫预测图

 从图8可以看出训练模型可以实现外置摄像头的实时推理检测。至此实验部分全部完成。

5 总结

PS:总体实验思路灵感是来自GitHub博主的一份代码。因此将我的整体思路记录在博客中,与大家共同讨论学习进步!

 如果大家需要预训练模型可以在评论区留言。

  • 0
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值