参考:
中文网站极客学院也有该部分的汉译版:http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/deep_cnn.html
网友学习经验帖:https://blog.csdn.net/yhl_leo/article/details/50738311
版本报错修改:https://blog.csdn.net/zeuseign/article/details/72771598
目录
1.介绍
对CIFAR-10 数据集的分类是机器学习中一个公开的基准测试问题,其任务是对一组大小为32x32的RGB图像进行分类,这些图像涵盖了10个类别:飞机,汽车,鸟,猫,鹿,狗,青蛙,马,船以及卡车。
2.准备
代码:https://github.com/sjl3110/Cifar-10-TensorFlow1.8.0
文件 | 作用 |
---|---|
cifar10_input.py | 读取本地CIFAR-10的二进制文件格式的内容。 |
cifar10.py | 建立CIFAR-10的模型。 |
cifar10_train.py | 在CPU或GPU上训练CIFAR-10的模型。 |
cifar10_multi_gpu_train.py | 在多GPU上训练CIFAR-10的模型。 |
cifar10_eval.py | 评估CIFAR-10模型的预测性能。 |
本代码适用于1.8版本的tensorflow,因为API较老版本变化很大,因此对于官网给的代码做出了一定的修改。除此之外,迭代次数修改为20000次。由于笔记本训练,实际上1060的GPU下20000次迭代就花费了十几分钟的时间。
3.训练
直接运行,开始会下载CIFAR的数据包(官网链接http://www.cs.toronto.edu/~kriz/cifar.html,下载第三个),大概国内得3小时...然后没报错的话基本上就可以开始迭代学习了
python3 cifar10_train.py
4.评估
脚本文件cifar10_eval.py
对模型进行了评估,利用 inference()
函数重构模型,并使用了在评估数据集所有10,000张CIFAR-10图片进行测试。最终计算出的精度为1:N,N=预测值中置信度最高的一项与图片真实label匹配的频次。
为了监控模型在训练过程中的改进情况,评估用的脚本文件会周期性的在最新的检查点文件上运行,这些检查点文件是由cifar10_train.py
产生。
python3 cifar10_eval.py
命令运行后会显示当前的精度,一般同时运行训练和评估会消耗大量的系统资源。