1. 下载 TensorFlow 源码
# Clone the TensorFlow source tree into ./tensorflow; the later steps cd into
# it and build tools/examples from it with bazel.
# NOTE(review): the bazel targets used below (e.g. tensorflow/tools/
# graph_transforms) are TF 1.x-era; if one is missing on the default branch,
# check out a matching release branch inside the clone (e.g.
# `git checkout r1.15`) — confirm which version you need.
git clone https://github.com/tensorflow/tensorflow tensorflow
2. 安装bazel
# Install Bazel from its official apt repository.
# Refresh package lists first so the prerequisite install works on a fresh system.
sudo apt-get update
# Prerequisites: JDK 8, curl + gnupg to fetch and de-armor the signing key, and
# apt-transport-https so older apt versions can use an https:// repository.
sudo apt-get install openjdk-8-jdk curl gnupg apt-transport-https
# Store the repository key in a dedicated keyring instead of the deprecated
# apt-key; -f makes curl fail on HTTP errors instead of piping an error page on.
curl -fsSL https://bazel.build/bazel-release.pub.gpg \
  | gpg --dearmor \
  | sudo tee /usr/share/keyrings/bazel-archive-keyring.gpg >/dev/null
# Register the repository over https, trusting only the keyring above.
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/bazel-archive-keyring.gpg] https://storage.googleapis.com/bazel-apt stable jdk1.8" \
  | sudo tee /etc/apt/sources.list.d/bazel.list
sudo apt-get update
sudo apt-get install bazel
# To upgrade an existing Bazel installation later, upgrade just that package.
# NOTE(review): TensorFlow only builds with a specific Bazel version range; if
# the build rejects the installed version, install a matching versioned
# package from this repo (e.g. `sudo apt-get install bazel-<version>`).
sudo apt-get install --only-upgrade bazel
3. 编译summarize_graph工具
# Work from the top of the clone; the bazel-bin paths used later are relative to it.
cd tensorflow
# Configure the build once before any bazel build, as in TensorFlow's
# build-from-source guide (interactive; defaults are usually fine for CPU-only).
./configure
bazel build tensorflow/tools/graph_transforms:summarize_graph
4. 查看 pb 模型的 input_arrays、input_shapes、output_arrays 等信息
# Inspect the frozen graph; the input/output names and shapes it reports are
# what step 5 passes to the converter. Run from the tensorflow checkout root.
./bazel-bin/tensorflow/tools/graph_transforms/summarize_graph \
  --in_graph=/tmp/output_graph.pb
5. pb转tflite
# Convert the frozen GraphDef to a TFLite flatbuffer.
# Placeholder / final_result and the 1,299,299,3 shape are this model's values;
# replace them with what summarize_graph reports in step 4.
# Uses the documented tflite_convert flags (plural --input_arrays/--input_shapes/
# --output_arrays, --inference_input_type); --input_format and
# --input_data_type are toco flags that tflite_convert does not define.
# NOTE(review): on TensorFlow 2.x, --graph_def_file also needs --enable_v1_converter.
tflite_convert \
  --graph_def_file=/tmp/output_graph.pb \
  --output_file=/tmp/output_graph.tflite \
  --output_format=TFLITE \
  --input_arrays=Placeholder \
  --input_shapes=1,299,299,3 \
  --output_arrays=final_result \
  --inference_type=FLOAT \
  --inference_input_type=FLOAT
或者使用以下 Python 代码转换：
# Python equivalent of the tflite_convert command above.
# tf.contrib.lite.TocoConverter is deprecated (tf.contrib is removed in later
# releases); the frozen-graph converter is tf.compat.v1.lite.TFLiteConverter
# (TF >= 1.14, including 2.x).
import tensorflow as tf

graph_def_file = "/tmp/output_graph.pb"
input_arrays = ["Placeholder"]
output_arrays = ["final_result"]
# Pin the input shape to the same 1,299,299,3 used by the CLI conversion.
input_shapes = {"Placeholder": [1, 299, 299, 3]}

converter = tf.compat.v1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file, input_arrays, output_arrays, input_shapes=input_shapes)
tflite_model = converter.convert()
# Write to the same path the test steps read from; `with` guarantees the file
# is flushed and closed.
with open("/tmp/output_graph.tflite", "wb") as f:
    f.write(tflite_model)
6. Python测试
# Build the Python label_image example, then classify one image with the
# converted model. bazel-bin paths are relative to the tensorflow checkout.
bazel build tensorflow/lite/examples/python:label_image
./bazel-bin/tensorflow/lite/examples/python/label_image \
  --image=/home/test.jpg \
  --model_file=/tmp/output_graph.tflite \
  --label_file=/tmp/output_labels.txt
7. C++ 测试（注：只支持 BMP 格式图片）
# Build the C++ label_image example, then classify one image (must be BMP).
bazel build tensorflow/lite/examples/label_image:label_image
# Model and labels use the same /tmp paths that step 5 writes and step 6 reads.
./bazel-bin/tensorflow/lite/examples/label_image/label_image \
  --image=/home/test.bmp \
  --tflite_model=/tmp/output_graph.tflite \
  --labels=/tmp/output_labels.txt