用tensorflow 的时候经常遇到模型的保存和加载问题,今天在使用的时候遇到了一点点小问题,经过训练保存的模型在加载的时候意外的出现了bug,仔细查找资料后才发现是自己对tensorflow的API不够熟练
模型的保存
我们一般训练的时候是在Session里训练的,所以保存的例如:
saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
with tf.Session as sess:
运行的代码
saver.save(sess, checkpoint_prefix, global_step=current_step)
save 函数的接口为:
save(
sess,
save_path,
global_step=None,
latest_filename=None,
meta_graph_suffix='meta',
write_meta_graph=True,
write_state=True
)
一般想要保存模型,我们需要定义一个对象:saver
通常是这样的:
saver = tf.train.Saver(tf.global_variables(), max_to_keep=100)
tf.train.Saver()
这个函数可以将模型的所有参数都保存下来,并且它支持保存部分参数。该Saver类添加OPS保存和恢复变量和检查站 。它还提供了方便的方法来运行这些操作。
检查点是一个专有格式的二进制文件,其映射变量名张量的值。检查检查点的内容,最好的办法是使用加载它Saver 。
储户可自动检查站数名与所提供的计数器。这可以让你保持多个检查点,在不同的步骤,而训练的模型。例如,你可以数随训练步数检查点的文件名。为了避免填满磁盘,储户自动管理检查点文件。例如,他们可以只保留N个最近的文件,或每N个小时的训练一个检查点。
a) Meta graph:
这是一个协议缓冲区(protocol buffer),它完整地保存了Tensorflow图;即所有的变量、操作、集合等。此文件以 .meta
为拓展名。
b) Checkpoint 文件:
这是一个二进制文件,包含weights、biases、gradients 和其他所有变量的值。此文件以 .ckpt 为扩展名. 但是,从Tensorflow 0.11
版本之后做出了一些改变。现在,不再是单一的.ckpt
文件,而是两个文件(.data和.index)
.data文件包含了我们的训练变量
存储的文件如下:
其中:checkpoint 存储着最近保存的文件列表,是一个文本文档,可以查看
model-464940.meta
保存的就是网络数据,这个其实只需要保存一次就行了,因为后面随着模型训练,网络是不会改变的,只有参数会改变。
model-463940.data-00000-of-00001
和 model-464940.index
是checkpoint 文件,保存我们训练的变量信息,464940
这个号码是系统根据你的设置生成的,例如你设置每1000步保存一次,那这个数字就会是1000,2000, 3000,...
注意事项
用saver保存模型的时候,saver.save(sess, checkpoint_prefix, global_step=current_step)
其中checkpoint_prefix 这个参数 是需要保存的名称前缀,而不是全称,例如上面我们的模型是model-464940.meta
那我们的checkpoint_prefix 就是dir/model
后面的都是API自己生成
模型加载
模型加载和模型保存是一致的看案例:
with tf.Session() as sess:
saver = tf.train.import_meta_graph({}.meta".format(model_file))
sess.run(tf.initialize_all_variables())
saver.restore(self.sess, model_file)
tf.train.import_meta_graph
直接加载meta
文件
saver.restore()
加载参数文件
记住,这个restore 不需要指定.data
或者.index
这两个文件都是需要的,所以只需要到模型的名称那就可以了