第一步:保存模型的graph
完成这一步很简单,在测试脚本中的sess.run之前添加一行代码即可:
tf.train.write_graph(sess.graph_def, "./", 'test_graph.pb') #此处的sess是定义的tf.Session(), “./”代表保留到当前路径.
第二步:整理模型训练时保存的文件
tensorflow训练时会保存三个文件,后缀名分别是.data-00000-of-00001, .index 和 .meta。在这一步只需要将这三个文件和上一步保存的test_graph.pb放在同一个路径下。
第三步:安装bazel
我们待会需要用到tensorflow提供的freeze_graph工具进行模型固化,需要用bazel进行编译,安装详见官网: https://docs.bazel.build/versions/master/install-ubuntu.html .
第四步:编译freeze_graph
tensorflow官网对之后的步骤介绍得都非常清楚,详见:https://www.tensorflow.org/mobile/prepare_models
首先到github下载tensorflow的源码,然后进入文件夹编译,过程很简单:
bazel build tensorflow/python/tools:freeze_graph
第五步:固化模型
同样,按照官网的介绍,主要正确设置好路径和变量名就行了。需要注意的是,input_checkpoint的名字是tensorflow保存的三个文件的文件名除去后缀名后相同的那部分。
附官网代码:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=/tmp/model/my_graph.pb \
--input_checkpoint=/tmp/model/model.ckpt-1000 \
--output_graph=/tmp/frozen_graph.pb \
--output_node_names=output_node \
至此,大功告成!