tensorflow 保存和加载模型Saver的使用

1.     注意:如果保存模型和加载模型在两个.py文件中,在Spyder的Ipython console中保存完模型后,v1是'Variable:0',v2是'Variable_1:0',紧接着运行加载模型,此文件中的v1,v2就是'Variable_2:0'和'Variable_3:0',加载会出错,需关掉Ipython console后,重新运行使得modelLoad.py中的v1和v2分别是'Variable:0'和'Variable_1:0'才能运行成功。

保存模型

import tensorflow as tf  

#声明两个变量并计算它们的和
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='v2'))
result = v1 + v2

init_op = tf.global_variables_initializer()
#声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #将模型保存到model/model.ckpt文件
    saver.save(sess,'model/model1/model.ckpt')
加载模型

import tensorflow as tf

#使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='v2'))
result = v1 + v2

#声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess1:
    #加载已经保存的模型,并通过已经保存的模型中变量的值来计算加法
    saver.restore(sess1,'model/model1/model.ckpt')
    print(sess1.run(result))

runfile('F:/python学习201712/TensorFlow/modelLoad.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model1/model.ckpt
[ 3.]

2. 如果不希望重复定义图上的运算,也可以直接加载已持久化的图

import tensorflow as tf

#如果不希望重新定义图上的运算,也可以直接加载已经持久化的图
saver = tf.train.import_meta_graph('model/model1/model.ckpt.meta')

with tf.Session() as sess:
    saver.restore(sess,'model/model1/model.ckpt')
    #通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0')))

runfile('F:/python学习201712/TensorFlow/modelLoad2.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model1/model.ckpt
[ 3.]

3. tf.train.Saver支持在保存或者加载时给变量重命名

import tensorflow as tf
"""为了保存或者加载部分变量,在声明tf.train.Saver类时可以提供一个列表来指定需要保存
或者加载的变量。"""
#这里声明的变量名称和已经保存的模型中变量的名称不同
v1 = tf.Variable(tf.constant(1.0,shape=[1],name='other_v1'))
v2 = tf.Variable(tf.constant(2.0,shape=[1],name='other_v2'))

#如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误。

#使用一个字典来重命名变量就可以加载原来的模型了。这个字典指定了原来名称为v1的变量现在
#加载到变量v1中(名称为other-v1),名称为v2的变量加载到变量v2中(名称为other-v2)
saver = tf.train.Saver({'v1':v1,'v2':v2})
4. 保存和加载滑动平均模型

import tensorflow as tf

"""保存滑动平均模型"""
v = tf.Variable(0, dtype=tf.float32, name='v')
#在没有申明滑动平均模型时只有一个变量v,所以下面的语句只会输出'v:0'
for variables in tf.global_variables():
    print(variables.name)
    print('='*60)

ema = tf.train.ExponentialMovingAverage(0.99)
maintain_average_op = ema.apply(tf.global_variables())

#在申明滑动平均模型之后,Tensorflow会自动生成一个影子变量
#v/ExponentialMoving Average,于是下面的语句输出
# 'v:0'和'v/ExponentialMovingAverage:0'
for variables in tf.global_variables():
    print(variables.name)

saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    
    sess.run(tf.assign(v,10))
    sess.run(maintain_average_op)
    #保存时,Tensorflow会将v:0和v/ExponentialMovingAverage:0两个变量都存下来
    saver.save(sess,'model/model2/model.ckpt')
    print(sess.run([v,ema.average(v)]))
#通过变量重命名直接读取变量的滑动平均值
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
#通过变量重命名将原来变量v的滑动平均值直接赋值给v
saver = tf.train.Saver({'v/ExponentialMovingAverage':v})
with tf.Session() as sess:
    saver.restore(sess,'model/model2/model.ckpt')
    print(sess.run(v))

#为了加载时重命名滑动平均变量,tf.train.ExponentialMovingAverage类提供了
#variables_to_restore函数来生成tf.train.Saver类所需要的变量重命名字典。
ema = tf.train.ExponentialMovingAverage(0.99)

#通过使用variables_to_restore函数可以直接生成上面代码中提供的字典
# {'v/ExponentialMovingAverage':v}
#以下代码会输出
#{'v/ExponentialMovingAverage':<tensorflow.python.ops.variable.Variable object
#  at 0x7ff6454ddc10>}
#其中后面的Variable类就代表了变量v
print(ema.variables_to_restore())

saver = tf.train.Saver(ema.variables_to_restore())
with tf.Session() as sess:
    saver.restore(sess,'model/model2/model.ckpt')
    print(sess.run(v))
runfile('F:/python学习201712/TensorFlow/modelLoad3.py', wdir='F:/python学习201712/TensorFlow')
INFO:tensorflow:Restoring parameters from model/model2/model.ckpt
0.0999999
{'v/ExponentialMovingAverage': <tf.Variable 'v:0' shape=() dtype=float32_ref>}
INFO:tensorflow:Restoring parameters from model/model2/model.ckpt
0.0999999










  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值