tensorflow通过tf.train.Saver()保存模型得到了四个文件:
checkpoint 检查点文件
model.ckpt.data-xxx 保存的是参数的值
model.ckpt.index 保存的是各个参数
model.ckpt.meta 保存的是图的结构
可通过saver.restore()恢复整个神经网络。
但是这种方式有几个缺点,首先这种模型文件是依赖 TensorFlow 的,只能在其框架下使用;其次,在恢复模型之前还需要再定义一遍网络结构,然后才能把变量的值恢复到网络中。
ps:
使用chpt也可以直接加载网络结构,不需要重新定义网络:
saver = tf.train.import_meta_graph('./ckpt/model.ckpt.meta')
谷歌推荐的保存模型的方式是保存模型为 PB 文件,它具有语言独立性,可独立运行,封闭的序列化格式,任何语言都可以解析它,它允许其他语言和深度学习框架读取、继续训练和迁移 TensorFlow 的模型。
它的主要使用场景是实现创建模型与使用模型的解耦, 使得前向推导 inference的代码统一。
另外的好处是保存为 PB 文件时候,模型的变量都会变成固定的,导致模型的大小会大大减小,适合在手机端运行。
这种 PB 文件是表示 MetaGraph 的 protocol buffer格式的文件。
Graphdef
中不保存任何 Variable 的信息,所以如果从 graph_def
来构建图并恢复训练的话,是不能成功的.
Meta Graph在具体实现上就是一个 MetaGraphDef
(同样是由 Protocol Buffer来定义的). 其包含了四种主要的信息,根据Tensorflow官网,这四种 Protobuf 分别是:
[1] - MetaInfoDef
,存一些元信息(比如版本和其他用户信息)
[2] - GraphDef
, MetaGraph 的核心内容之一
[3] - SaverDef
,图的Saver信息(比如最多同时保存的check-point数量,需保存的Tensor名字等,但并不保存Tensor中的实际内容)
[4] - CollectionDef
,任何需要特殊注意的 Python 对象,需要特殊的标注以方便import_meta_graph
后取回(如 train_op
, prediction
等等)
import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
def save_mode_pb(pb_file_path):
x = tf.placeholder(tf.int32, name='x')
y = tf.placeholder(tf.int32, name='y')
b = tf.Variable(1, name='b')
xy = tf.multiply(x, y)
# 这里的输出需要加上name属性
op = tf.add(xy, b, name='op_to_store')
sess = tf.Session()
sess.run(tf.global_variables_initializer())
path = os.path.dirname(os.path.abspath(pb_file_path))
if os.path.isdir(path) is False:
os.makedirs(path)
# convert_variables_to_constants 需要指定output_node_names,list(),可以多个
constant_graph = graph_util.convert_variables_to_constants(sess, sess.graph_def, ['op_to_store'])
with tf.gfile.FastGFile(pb_file_path, mode='wb') as f:
f.write(constant_graph.SerializeToString())
# test
feed_dict = {x: 2, y: 3}
print(sess.run(op, feed_dict))
这个过程的主要思路是pb文件中没有包含网络中的Variable值(通常情况存储了权重),但是却包含了constant值,所以如果我们能把Variable转换为constant(使用graph_util.convert_variables_to_constants()函数),即可达到使用一个文件同时存储网络架构与权重的目标,所以在上述实例中在pb文件会保存b的值,但x和y的值并没有保存。
import tensorflow as tf
from tensorflow.python.platform import gfile
def restore_mode_pb(pb_file_path):
sess = tf.Session()
with gfile.FastGFile(pb_file_path, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
print(sess.run('b:0'))
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
op = sess.graph.get_tensor_by_name('op_to_store:0')
ret = sess.run(op, {input_x: 5, input_y: 5})
print(ret)
上述代码是从pb文件中恢复网络结构和variables的值