前言
有时,我们需要保存tensorflow训练的模型:
- tf.train.write_graph()默认情况下只导出了网络的定义(没有权重)
- 利用tf.train.Saver().save()导出的文件graph_def与权重是分离的 为了方便使用模型,通过tensorflow.python.tools.freeze_graph可以将两者进行合并和优化最后得到最终的PB文件。
1.通过ckpt和tf.train.write_graph得到基础pb文件(无权重)
1.1在训练过程中使用使用tf.train.write_graph()以及tf.train.saver()生成pb文件和ckpt文件:
with tf.Session() as sess:
saver = tf.train.Saver()
saver.save(session, "model.ckpt")
tf.train.write_graph(session.graph_def, '', 'graph.pb')
1.2非训练过程中使用(加载网络生成pb文件):
import argparse
import logging
import tensorflow as tf
from tf_pose.networks import get_network
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s %(message)s')
config = tf.ConfigProto()
config.gpu_options.allocator_type = 'BFC'
config.gpu_options.per_process_gpu_memory_fraction = 0.95
config.gpu_options.allow_growth = True
if __name__ == '__main__':
"""
Use this script to just save graph and checkpoint.
While training, checkpoints are saved. You can test them with this python code.
"""
parser = argparse.ArgumentParser(description='Tensorflow Pose Estimation Graph Extractor')
parser.add_argument('--model', type=str, default='cmu', help='cmu / mobilenet_thin / my_mobilenet_thin')
args = parser.parse_args()
input_node = tf.placeholder(tf.float32, shape=(None, None, None, 3), name='image')
with tf.Session(config=config) as sess:
#网络结构以及ckpt参数
net, _, last_layer = get_network(args.model, input_node, sess, trainable=False)
#net:网络结构
#_ :ckpt参数
#last_layer:网络最后一层名称
print(last_layer)
tf.train.write_graph(sess.graph_def, 'models/graph/my_mobilenet_thin/', 'graph.pb', as_text=True)
关于函数get_network:
def get_network(type, placeholder_input, sess_for_load=None, trainable=True):
if type == 'mobilenet':
net = MobilenetNetwork({'image': placeholder_input}, conv_width=0.75, conv_width2=1.00, trainable=trainable)
pretrain_path = 'pretrained/mobilenet_v1_0.75_224_2017_06_14/mobilenet_v1_0.75_224.ckpt'
last_layer = 'MConv_Stage6_L{aux}_5'
return net, pretrain_path_full, last_layer
2.得到的pb文件与ckpt进行freezing:
$ python3 -m tensorflow.python.tools.freeze_graph \
--input_graph=... \
--output_graph=... \
--input_checkpoint=... \
--output_node_names="Openpose/concat_stage7"
参考资料: