tensorflow 4——模型的保存、读取

版权声明:本文为博主原创文章,遵循 CC 4.0 by-sa 版权协议,转载请附上原文出处链接和本声明。
本文链接:https://blog.csdn.net/weixin_40056577/article/details/79334698

tf.train.Saver类为tensorflow的一个API
可通过import tensorflow as tf
help(tf.train.Saver)来查看这个API的用法

import tensorflow as tf
v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(2.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(3.0,shape=[1]),name='v3')
result1=v1+v2
result2=v2+v3
print(result1,result2)

Tensor(“add:0”, shape=(1,), dtype=float32) Tensor(“add_1:0”, shape=(1,), dtype=float32)

可以看出两个张量的名称add:0和add_1:0。指的是加法的名称、次数以及一开始初始化的第一个值
接下来看如何保存图

saver=tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess,'/path/to/model.ckpt')  #然后就可以在路径下发现存储的文件
#接下来加载这些文件
import tensorflow as tf
saver1=tf.train.import_meta_graph('/path/to/model.ckpt.meta')#加载图
with tf.Session() as sess:
    saver1.restore(sess,'/path/to/model.ckpt') #将图上的数据加载进来
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

INFO:tensorflow:Restoring parameters from /path/to/model.ckpt
[ 3.]

若在上述会话中,print(result2)会出现result2 is not defined的问题。所以加载图和张量时,只能依靠张量的名称来获取其值,接下来看看如何重定义或加载本来保存在图里的张量v1,v2,result

a1=tf.Variable(tf.constant(3.0,shape=[1]),name='a1')
a2=tf.Variable(tf.constant(4.0,shape=[1]),name='a2')
saver=tf.train.Saver({'v1':a1,'v2':a2}) #将原本图上的变量v1,v2加载过来到新的张量a1,a2上,可以看到a1为1而不是3
with tf.Session() as sess:
    saver.restore(sess,'/path/to/model.ckpt')
    print(sess.run(a1))

INFO:tensorflow:Restoring parameters from /path/to/model.ckpt
[ 1.]
一般情况下,从保存的模型文件中加载计算原图meta,再restore所需变量值就可以达到调控模型参数的目的。
即模型每训练1000次保存一个模型,假若发现第4000次训练过拟合,第3000次训练的模型不太理想,则可加载restore第3000次模型的计算图以及变量,并在这基础上训练500次,来方便地实现调控模型训练次数的方式而又避免了重复训练,并可作多项研究。

saver指定要保存的变量,saver.save则指定在某个会话下,模型保存的路径,以及全局迭代的次数。
| saver = tf.train.Saver(…variables…)
| # Launch the graph and train, saving the model every 1,000 steps.
| sess = tf.Session()
| for step in xrange(1000000):
| sess.run(..training_op..)
| if step % 1000 == 0:
| # Append the step number to the checkpoint name:
| saver.save(sess, ‘my-model’, global_step=step)

#一个滑动平均类变量的保存
import tensorflow as tf
v=tf.Variable(0,dtype=tf.float32,name='v')
ema=tf.train.ExponentialMovingAverage(0.99)
print(ema.variables_to_restore())
{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}

总结:
1.tf.train.Saver(…variables)
2.variables一般重命名加载,获取可由reader或者variables_to_restore()函数来获取相应的列表
3.保存的路径,以及global step的设置,默认情况下,保存的模型文件最多5个,可自行阅读参数修改
4.具体完整的应用在tensorflow代码梳理2的神经网络中

展开阅读全文

没有更多推荐了,返回首页