TensorFlow学习笔记(2)——保存和加载训练模型参数

在 Tensorflow 的训练过程中,参数通过使用Saver 类来保存和加载模型。

Saver类可以处理图中元数据和变量数据的保存和加载。我们只需要告诉 Saver 类我们需要保存哪个图? 哪些变量?

默认情况下,Saver类能处理默认图中包含的所有变量。


保存模型:

#导入tensorflow

import  tensorflow  as  tf

#定义两个变量x1,x2

x1 = tf.Variable ( 5. ,  name = "x1" )
x2 = tf.Variable ( 20. ,  name = "x2" )
# 定义一个加法操作
y = tf.add ( x1,  x2 )
#创建一个Saver对象,默认情况下,Saver操作对象是默认graph里面所有的变量
saver_all = tf.train.Saver ( )
#创建一个Saver对象,只保存变量X2
saver_x2 = tf.train.Saver ( { "x2" : x2 } )
#Session默认处理默认graph
with  tf.Session ( )  as  sess:
  # 初始化所有变量 
  sess.run ( tf.global_variables_initializer ( ) )
  #保存所有变量到data_all.chkp
  saver_all.save ( sess,  './data_all.chkp' )
  #保存变量y到data_x2.chkp
  saver_x2.save ( sess,  './data_x2.chkp' )
  print ( "x1=",sess.run ( x1 ) ,", x2=",sess.run ( x2 ) , "y=", sess.run ( y ) )
  print("save variables complete!")


程序运行结果如下:

=================== RESTART: D:\Python_code\ML\tt\saver.py ===================
x1= 5.0 , x2= 20.0 y= 25.0
save variables complete!
>>> 

并且在当前save目录下产生四个文件:checkpoint、 data_all.chkp.data-00000-of-00001、data_all.chkp.index、data_all.chkp.meta。


加载模型:

import  tensorflow  as  tf
#定义两个变量x1,x2
x1 = tf.Variable ( 0. ,  name="x1" )
x2 = tf.Variable ( 0. ,  name="x2" )
# 定义一个加法操作
y = tf.add ( x1,  x2 )
#创建一个Saver对象
restore_all = tf.train.Saver ( )
#Session默认处理默认graph
with  tf.Session ( )  as  sess:
  #从data_all.chkp加载模型参数
  restore_all.restore ( sess,  './save/data_all.chkp' )
  print ( "x1=", sess.run ( x1 ), ", x2=", sess.run ( x2 ) ,"y=", sess.run ( y ) )
  print ( "restore variables complete!" )

程序运行结果如下:

================== RESTART: D:\Python_code\ML\tt\restore.py ==================
x1= 5.0 , x2= 20.0 y= 25.0
restore variables complete!
>>> 

可以看出,模型完全恢复了。




评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值