变量保存与导入
tensorflow内置的参数导出和导入基本用法,用于保存训练好的模型参数
import tensorflow as tf
"""
变量声明,运算声明 例:w = tf.get_variable(name="vari_name", shape=[], dtype=tf.float32)
初始化op声明
"""
#创建saver对象,它添加了一些op用来save和restore模型参数
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)
#训练模型。。。
#使用saver提供的简便方法去调用 save op
saver.save(sess, "save_path/file_name.ckpt") #file_name.ckpt如果不存在的话,会自动创建
#后缀可加可不加
现在,训练好的模型参数已经存储好了,我们来看一下怎么调用训练好的参数。变量保存的时候,保存的是变量名:value,键值对。restore的时候,也是根据key-value来进行的
import tensorflow as tf
"""
变量声明,运算声明
初始化op声明
"""
#创建saver 对象
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init_op)#在这里,可以执行这个语句,也可以不执行,即使执行了,初始化的值也会被restore的值给override
saver.restore(sess, "save_path/file_name.ckpt-???")
#会将已经保存的变量值resotre到变量中,自己看好要restore哪步的
如何restore变量的子集,然后使用初始化op初始化其他变量
#想要实现这个功能的话,必须从Saver的构造函数下手
saver=tf.train.Saver([sub_set])
init = tf.initialize_all_variables()
with tf.Session() as sess:
#这样你就可以使用restore的变量替换掉初始化的变量的值,而其它初始化的值不受影响
sess.run(init)
if restor_from_checkpoint:
saver.restore(sess,"file.ckpt")
# train
saver.save(sess,"file.ckpt")