可放在ulibs.py中作为常用功能库函数用
TensorFlow提供了convert_variables_to_constants函数,通过这个函数可以将计算图中的变量及其聚会通过常量的方式保存,这样整个TensorFlow计算图可以统一放在一个文件中。
模型最终要关注的是输入和输出,因此保存模型的时候通常就指定导出这两个节点就够了:
import tensorflow as tf from tensorflow.python.platform import gfile from tensorflow.python.framework import graph_util def model_save(sess, model_path, input_tensor_name, bottleneck_tensor_name): graph_def = tf.get_default_graph().as_graph_def() outpput_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, [input_tensor_name, bottleneck_tensor_name]) with tf.gfile.GFile(model_path, "wb") as wf: wf.write(outpput_graph_def.SerializeToString()) def model_restore(model_path, input_tensor_name, bottleneck_tensor_name): with gfile.FastGFile(model_path, 'rb') as rf: graph_def = tf.GraphDef() graph_def.ParseFromString(rf.read()) in_tensor, out_tensor, = tf.import_graph_def(graph_def, return_elements=[input_tensor_name, bottleneck_tensor_name]) return in_tensor, out_tensor
参数:
model_path:指定了模型文件所在的路径;
input_tensor_name: 模型的输入张量名称;
bottelneck_tensor_name: 模型的瓶颈张量;
sess: 保存模型时需要传入当前的会话;