模型的持久化 tf.train.Saver()

模型持久化:将训练得到的模型保存下来方便下次直接使用。

代码实现:
保存模型

tensorflow提供一个 简单的api保存和还原一个 神经网络模型。
tf.train.Saver()

import tensorflow as tf

#声明两个变量并计算他们的和

v1 = tf.Variable(tf.constant(1.0, shape=[1], name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape = [1], name= "v2")

result = v1 + v2

init_op = tf.initialize_all_variables()

#声明 tf.train.Saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #将模型保存到/path/to/model/model.ckpt文件
    saver.save(sess, "/path/to/model/model/ckpt")

文件目录下将会出现三个文件:
第一个文件为model.ckpt.meta:保存了tensorflow 计算图结构
第二个文件为model.ckpt: 保存来程序中每一个变量的取值
最后一个文件为checkpoint: 保存了一个目录下所有模型的文件列表。

  • 加载模型
import tensorflow as tf

#使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = "v1")
v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = "v2")

result =  v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    #加载已经保存的模型,并通过已经保存的模型中的变量的值来计算加法
    saver.restore(sess, "/path/to/model/model.ckpt")
    print(sess.run(result))

和保存模型的不同:未运行变量的初始化过程,而是将变量的值通过已经保存的模型加载进来。

如果不希望重复定义图上的运算,可以直接加载已经持久化的图。

import tensorflow as tf
#直接加载持久化的图
saver = tf.train.import_meta_graph(
    "/path/to/model/model.ckpt/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model/model.ckpt")
    #通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0))
    #输出[3.]

tf.train.Saver() 类可以提供一个列表来制定需要保存或者加载的变量。
在加载模型的代码中 使用saver = tf.train.Saver([v1]) 命令构建tf.train.Saver类
,那么只有v1会被加载进来。
将会报错: Attempting to use uninitialized value v2
因为v2未加载。所以v2在运行初始化之前是没有值的。

  • tf.train.Saver类也支持在保存或者加载时加载时给变量重命名。
#这里声明的变量名称和已经保存的模型中的变量名称不同
v1 = tf.Variable(tf.constant(1.0, shape = [1]), name = 'other-v1')
v2 = tf.Variable(tf.constant(2.0, shape = [1]), name = 'other-v2')

#直接使用 tf.train.Saver() 来加载模型会报变量找不到的模型。
#NotFoundError: Tensor name "other-v2" #not found in checkpoint files /path/to#/model/model.ckpt

#使用一个字典(dictionary)来重命名变量可以就可以加载原来的模型了。这个字典制定来原来的名称为v1的变量现在加载到变量为v1中(名称为other-v1),名称为v2的变量加载到变量为v2中(名称为other-v2)

saver = tf.train.Saver({"v1":v1, "v2": v2})
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值