最近在研究移动端的时候涉及到了模型的冰冻,即将训练得到的模型生成对应的pb文件。因此研究了一下tensorflow中的几种保存模型的方式,具体如下:
一、save保存。
save保存一定要在session中进行,并且save保存时会保存所有的参数信息,而这些信息是我们不一定需要的。并且save保存一般保存的是所有的网络和参数。
save的存储
tf.global_variables_initializer().run() # 初始化所有变量
saver=tf.train.Saver() # 参数为空,默认保存所有变量
saver=tf.train.Saver([w,b]) # 保存部分变量
saver.save(sess,logdir+'model.ckpt')
save的使用:
tf.global_variables_initializer().run() # 初始化所有变量
# 验证之前是否已经保存了检查点文件
ckpt = tf.train.get_checkpoint_state(logdir)
if ckpt and ckpt.model_checkpoint_path:
try:
saver = tf.train.Saver() # 参数为空,默认保存所有变量,这里只有变量w1、b1
saver.restore(sess, ckpt.model_checkpoint_path)
saver=None
except:
saver = tf.train.Saver([w1,b1]) # 参数为空,默认保存所有变量,这里只有变量w1、b1
saver.restore(sess, ckpt.model_checkpoint_path)
saver = None
注意在使用前要进行初始化。主要是对没有保存的变量赋值。另外此方法可以对前几层的变量不变,最后一层变量赋初始值从新训练。主要在借鉴前人的模型的时候可以使用。
二、write_grape:该方法主要是将图保存起来,保存结果只含图不含其它任何数据。以后遇到再补充。
三、convert_variabe_to_constants:该方法将图和数据一起保存。在保存时会将图中的变量取值以常量的形式保存。在保存模型时只保存了GraphDef部分,GraphDef保存了从输入层到输出层的计算过程。在保存时通过convert_variable_to_constants函数来指定保存的节点名称。具体使用方法为:
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def,outpit_node_name=['name']])
with tf.gfile.FastGFile('path/name.pb', mode='wb') as f:
f.write(constant_graph.SerializeToString())
sess.close()
调用方法为:
detection_graph = tf.Graph()
with detection_graph.as_default():
od_graph_def = tf.GraphDef()
with tf.gfile.GFile(filename, 'rb') as fid:
serialized_graph = fid.read()
od_graph_def.ParseFromString(serialized_graph)
tf.import_graph_def(od_graph_def, name='')
# 模型运行
with detection_graph.as_default():
with tf.Session(graph=detection_graph) as sess:
导入之后:首先通过下面的语句来进行张量的名称。然后将张量的名称传给具体的节点。注意后面必须是0
input_x = sess.graph.get_tensor_by_name("input:0")
keep_prob = sess.graph.get_tensor_by_name("keep_prob:0")
最后链接几个关于存储的介绍
https://blog.csdn.net/c2a2o2/article/details/72778628
https://blog.csdn.net/sinat_29957455/article/details/78511119
https://blog.csdn.net/wc781708249/article/details/78039029