【TensorFlow】模型持久化tf.train.Saver—上(八)

对于官方的MNIST的例子,训练完之后就直接退出了,并没有将训练得到的模型保存下来方便下次直接使用。为了让训练结果可以复用,需要将训练得到的神经网络持久化。

TensorFlow提供了一个非常简单的API来保存和还原一个神经网络模型,这个模型就是tf.train.Saver类。
TensorFlow通过tf.train.Saver类实现神经网络模型的保存和提取,tf.train.Saver对象saver的save方法将TensorFlow模型保存到指定路径中,saver.save(sess,”Model/model.ckpt”):

import tensorflow as tf
#import os
v1 = tf.Variable(tf.constant(1.0),name = 'v1')
v2 = tf.Variable(tf.constant(2.0),name = 'v2')

result = v1+v2

init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    #一定将path定义为'Models/model.ckpt'的模式,文件夹/文件名.文件类型
    saver.save(sess,'Models/model.ckpt')

实际在这个文件目录下会生成三个人文件和一个checkpoint文件:

这里写图片描述
1、model.ckpt.meta储存TensorFlow计算图的结构;
2、model.ckpt.dataXXX ,保存了每一个变量的取值;
3、checkpoint文件,这个文件保存一个目录下所有的模型文件列表。

保存好的文件要调用采用加载的方式:

import tensorflow as tf  

v1 = tf.Variable(tf.constant(1.0), name="v1")  
v2 = tf.Variable(tf.constant(2.0), name="v2")  
result = v1 + v2  

saver = tf.train.Saver()  

with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt") # 注意此处路径前添加"./"或者不加一样的效果
    print(sess.run(result)) # [ 3.] 

这段加载模型的代码和保存模型的代码基本上一样,在加载模型中也是定义了一个TensorFlow计算图上的所有运算,并声明一个tf.train.Saver类。两者唯一的差别在于加载模型中没有运行代码初始化的过程,而是将变量的值通过已经保存的模型加载进来。

如果不希望重复定义图上的运算,可以加载持久化的图,使用:
saver = tf.train.import_meta_graph(“Model/model.ckpt.meta”)

import tensorflow as tf  
#加载持久化的图  
saver = tf.train.import_meta_graph("Model/model.ckpt.meta")    
with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt") # 注意路径写法  
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]  

以上就是保存图的计算方法和加载的两种不同的方式,在下篇中具体应用部分加载、变量重命名、变量的滑动平均值等。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值