Retrain a tensorflow model based on Inception v3

本文在谷歌2015_CVPR Inception v3模型的基础上,结合花朵识别的具体问题重新训练该模型,以获取自己需要的tensorflow模型。

重新训练Inception v3实质是在原有模型输出层后,新加了一个输出层作为最终的输出层,我们只训练这个新加的输出层。这里使用了迁移学习的概念。

Transfer learning, which means we are starting with a model that has been already trained on another problem. We will then be retraining it on a similar problem. Deep learning from scratch can take days, but transfer learning can be done in short order.

准备

本节主要给出了训练tensorflow模型的一些前提条件。

硬件环境

  • Ubuntu 16.04

安装tensorflow

安装git

$ sudo apt-get update
$ sudo apt-get install git

准备训练样本

$ cd ~
$ mkdir tf_files
$ cd tf_files
$ curl -O http://download.tensorflow.org/example_images/flower_photos.tgz
$ tar xzf flower_photos.tgz
$ ls flower_photos

flower_photos.tgz有218MB。

[可选操作]

$ cd ~/tf_files
$ ls flower_photos/roses | wc -l
$ rm flower_photos/*/[3-9]*  # 删除70%的样本数量,减少训练时间。
$ ls flower_photos/roses | wc -l

开始训练

下载retrain脚本

该脚本会自动下载google Inception v3 模型相关文件。

$ cd ~/tf_files
$ curl -O https://raw.githubusercontent.com/tensorflow/tensorflow/r1.1/tensorflow/examples/image_retraining/retrain.py

启动tensorboard

$ cd ~/tf_files
$ tensorboard --logdir training_summaries &

Note:
This command will fail with the following error if you already have a tensorboard running:
ERROR:tensorflow:TensorBoard attempted to bind to port 6006, but it was already in use
You can kill all existing TensorBoard instances with: $ pkill -f "tensorboard"

启动训练脚本

$ cd ~/tf_files
$ python retrain.py \
  --bottleneck_dir=bottlenecks \
  --how_many_training_steps=500 \
  --model_dir=inception \
  --summaries_dir=training_summaries/basic \
  --output_graph=retrained_graph.pb \
  --output_labels=retrained_labels.txt \
  --image_dir=flower_photos

如果不添加--how_many_training_steps=500,默认值为4000。

启动浏览器查看tensorboard

等待~/tf_files/bottlenecks中的bottlenecks文件生成结束后,可以启动浏览器,在地址栏中输入localhost:6006并回车,来查看训练进度。

小结

The retraining script will write out a version of the Inception v3 network with a final layer retrained to your categories to tf_files/retrained_graph.pb and a text file containing the labels to tf_files/retrained_labels.txt.
该图像识别模型,训练后的图像识别准确率应该在85%到99%。

测试重新训练的模型

$ cd ~/tf_files
$ curl -L https://goo.gl/3lTKZs > label_image.py
$ python label_image.py flower_photos/roses/2414954629_3708a1a04d.jpg 

你应该看到类似以下的结果:

daisy (score = 0.99071)
sunflowers (score = 0.00595)
dandelion (score = 0.00252)
roses (score = 0.00049)
tulips (score = 0.00032)

参考

TensorFlow For Poets

  • 0
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Digital2Slave

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值