tensorflow 模型与数据的存储与恢复

本文详细介绍了在TensorFlow中如何保存和恢复模型,包括完整模型的保存与恢复,部分变量的保存与恢复,以及检查点文件的内容查看。通过实例展示了在模型训练过程中遇到意外中断时,如何避免重复训练并有效利用已训练结果。
摘要由CSDN通过智能技术生成

我们在tensorflow中训练模型时,往往一个模型需要训练好几天,训练中可能出现意外关机而停止训练。如果此时再重头开始,必然使得前面已经训练的结果功亏一篑。于是,就有了save与restore。变量在文件中以name为名称存储

import tensorflow as tf
# 定义存储路径
save_path = "./test/model.ckpt"

# 存储变量
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
init_op = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init_op)
    inc_v1.op.run()
    dec_v2.op.run()
    save_path = saver.save(sess,save_path=save_path) 

上面代码执行以后,将在save_path文件夹下生成四个文件,分别命名为:

ckeckpoint

model.ckpt.data-00000-of-00001

model.ckpt.index

model.ckpt.meta

此时,若全部注释掉上面的代码,再运行下面的代码,将会恢复所有变量

import tensorflow as tf
# 定义存储路径
save_path = "./test/model.ckpt"
tf.reset_default_graph()
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess,save_path)
    print(v1.eval(), v2.eval()) 

但是,如果我们多定义了一个变量v3,再运行如下代码,将会报错:

import tensorflow as tf
# 定义存储路径
save_path = "./test/model.ckpt"
tf.reset_default_graph()
v1 = tf.get_variable(name="v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable(name="v2", shape=[5], initializer=tf.zeros_initializer)
v3 = tf.get_variable(name="v3", shape=[5], initializer=tf.zeros_initializer)
saver = tf.train.Saver()
with tf.Session() as sess:
    v3.initializer.run()
    saver.restore(sess,save_path)
    print(v1.eval(), v2.eval(),v3.eval())
错误为:NotFoundError (see above for traceback): Key v3 not found in checkpoint.即,名字为v3的变量找不到。也即是说,模型的变量数不能多于checkpoint中存储的变量数。

但是,如果我们只训练了一个五层网络,将这五层网络的变量存储到了checkpoint,现在我们要拓展网络到第六层,想用之前训练的五层网络的参数来初始化现在的第六层网络,怎么办呢?于是就遇到了我们需要存储与恢复部分变量的问题。我们可以以列表或者字典的形式在tf.train.Saver()中传递我们要存储与恢复的参数,接着上面代码中生成的checkpoint文件,继续运行如下代码:

tf.reset_default_graph()
v1 = tf.get_variable("v1", [3], initializer = tf.zeros_initializer)
v2 = tf.get_variable("v2", [5], initializer = tf.zeros_initializer)

# 也可以写作saver = tf.train.Saver([v2])
saver = tf.train.Saver({"v2":v2})
# Use the saver object normally after that.
with tf.Session() as sess:
  # 既然不需要从checkpoint文件中恢复v1,那么就需要手动初始化v1
  v1.initializer.run()
  saver.restore(sess,save_path)
  print("v1 : %s" % v1.eval())
  print("v2 : %s" % v2.eval())

输出 :

v1 : [0. 0. 0.]
v2 : [-1. -1. -1. -1. -1.]

可见,只有v2是从checkpoint中恢复的。

我们可以定义多个Saver()对象,也可以将同一个变量挂在多个Saver()对象里面。

有时候,我们想知道checkpoint里面存储了哪些变量,或者他们的值是多少,或者只关心给定的变量,可用如下代码:

from tensorflow.python.tools import inspect_checkpoint as chkp
chkp.print_tensors_in_checkpoint_file(save_path, tensor_name='', all_tensors=True,all_tensor_names=True)
print("***************************************")
chkp.print_tensors_in_checkpoint_file(save_path, tensor_name='v1', all_tensors=False,all_tensor_names=False)
print("***************************************")
chkp.print_tensors_in_checkpoint_file(save_path, tensor_name='v2', all_tensors=False,all_tensor_names=True) 

输出:

tensor_name:  v1
[1. 1. 1.]
tensor_name:  v2
[-1. -1. -1. -1. -1.]
***************************************
tensor_name:  v1
[1. 1. 1.]
***************************************
tensor_name:  v1

tensor_name:  v2

如果tensor_name="",当all_tensors=True时输出所有的tensors名及相应的取值。若all_tensors=False及all_tensor_names=True,则只输出所有的额变量名,不输出其取值。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值