分类目标为输入一张自行车图片,判断是山地车还是公路车。
第一步:
在百度图片分别爬取5000张山地车和公路车的图片,放于data/mountain和data/road两个文件夹下
第二步:
用TensorFlow自带的工具来fine-tuning训练mobilenet:
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow
CUDA_VISIBLE_DEVICES=0 python tensorflow/tensorflow/examples/image_retraining/retrain.py \
--image_dir /media/bike/data/ \
--output_graph /media/bike/model/output_graph.pb \
--intermediate_output_graphs_dir /media/bike/model/ \
--output_labels /media/bike/model/output_labels.txt \
--summaries_dir /media/bike/retrain_logs \
--model_dir /media/bike/model/ \
--bottleneck_dir /media/bike/model/bottleneck \
--learning_rate=0.0001 \
--testing_percentage=20 \
--validation_percentage=20 \
--train_batch_size=128 \
--validation_batch_size=-1 \
--flip_left_right True \
--random_scale=30 \
--random_brightness=30 \
--eval_step_interval=100 \
--how_many_training_steps=1000 \
--architecture mobilenet_1.0_224
第三步:
将生成的pb文件转换为tflite文件:
bazel run --config=opt \
tensorflow/contrib/lite/toco:toco -- \
--input_file=output_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF \
--output_format=TFLITE \
--output_file=nsfw.lite \
--inference_type=FLOAT \
--input_type=FLOAT \
--input_arrays=input \
--output_arrays=final_result \
--input_shapes=1,224,224,3
- 一开始我用的是tensorflow.contrib.slim.python.slim.nets.resnet_v2模型,生成pb文件后转换成tflite文件后只有0KB,在GitHub上有人遇到了同样的问题,但作者还没回复。
参考GitHub上firewu给出的方法就可以解决了。
第四步:
用TensorFlow官方给出的demo,将tflite文件和label.txt放在demo/app/src/main/assets/ 文件夹下,最后用Android studio打包成apk。
这里要注意到我们训练的模型接收的图片格式为float32,与官方的demo格式不一致,具体修改方法参考GitHub上的issue。
最终移植到手机上的效果如下图:
参考资料:
- 知乎专栏 MobileNet教程:用TensorFlow搭建在手机上运行的图像分类器
- 知乎专栏 MobileNet教程(2):用TensorFlow做一个安卓图像分类App
- TensorFlow lite 文档
- Tensorflow Lite初探(Android)
- TensorFlow Lite学习笔记2:生成TFLite模型文件