用TensorFlow lite将MobileNet移植到Android设备上

分类目标为输入一张自行车图片,判断是山地车还是公路车。

第一步:

在百度图片分别爬取5000张山地车和公路车的图片,放于data/mountain和data/road两个文件夹下

第二步:
用TensorFlow自带的工具来fine-tuning训练mobilenet:
git clone https://github.com/tensorflow/tensorflow.git
cd tensorflow
CUDA_VISIBLE_DEVICES=0 python tensorflow/tensorflow/examples/image_retraining/retrain.py \
    --image_dir /media/bike/data/ \
    --output_graph /media/bike/model/output_graph.pb \
    --intermediate_output_graphs_dir /media/bike/model/ \
    --output_labels /media/bike/model/output_labels.txt \
    --summaries_dir /media/bike/retrain_logs \
    --model_dir /media/bike/model/ \
    --bottleneck_dir /media/bike/model/bottleneck \
    --learning_rate=0.0001 \
    --testing_percentage=20 \
    --validation_percentage=20 \
    --train_batch_size=128 \
    --validation_batch_size=-1 \
    --flip_left_right True \
    --random_scale=30 \
    --random_brightness=30 \
    --eval_step_interval=100 \
    --how_many_training_steps=1000 \
    --architecture mobilenet_1.0_224

第三步:
将生成的pb文件转换为tflite文件:

bazel run --config=opt \
     tensorflow/contrib/lite/toco:toco -- \
     --input_file=output_graph.pb \
     --input_format=TENSORFLOW_GRAPHDEF \
     --output_format=TFLITE \
     --output_file=nsfw.lite \
     --inference_type=FLOAT \
     --input_type=FLOAT \
     --input_arrays=input \
     --output_arrays=final_result \
     --input_shapes=1,224,224,3

在这里我遇到过两个问题:
  1. 一开始我用的是tensorflow.contrib.slim.python.slim.nets.resnet_v2模型,生成pb文件后转换成tflite文件后只有0KB,在GitHub上有人遇到了同样的问题,但作者还没回复。
2. 在本地编译运行bazel会有报错:external/gemmlowp/public/../internal/../internal/kernel_default.h:88:2: error: "SIMD not enabled, you'd be getting a slow software fallback. Consider enabling SIMD extensions (for example using -msse4 if you're on modern x86). If that's not an option, and you would like to continue with the slow fallback, define GEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK."
参考GitHub上firewu给出的方法就可以解决了。

第四步:
用TensorFlow官方给出的demo,将tflite文件和label.txt放在demo/app/src/main/assets/ 文件夹下,最后用Android studio打包成apk。
这里要注意到我们训练的模型接收的图片格式为float32,与官方的demo格式不一致,具体修改方法参考GitHub上的issue


最终移植到手机上的效果如下图:





参考资料:

  1. 知乎专栏 MobileNet教程:用TensorFlow搭建在手机上运行的图像分类器
  2. 知乎专栏 MobileNet教程(2):用TensorFlow做一个安卓图像分类App
  3. TensorFlow lite 文档
  4. Tensorflow Lite初探(Android)
  5. TensorFlow Lite学习笔记2:生成TFLite模型文件


阅读更多
换一批

没有更多推荐了,返回首页