tensorflow持久化

tensorflow持久化

tensorflow代码运行完后就会自动退出,不会主动保存本次训练结果,如果我们花了很大的精力训练了一次结果,而没有保存,下一次又需要这个模型的时候又要花费同样的尽力来再一次训练,这显然是不可接受的。
为了让训练结果可以重复使用,需要将训练得到的神经网络模型持久化。

持久化

tensorflow提供了一个非常简单的API来保存和还原一个神经网络模型,这个API就是 tf.train.Saver 类。

以下代码给出了保存Tensorflow计算图的方法

import tensorflow as tf

# 声明两个变量并计算它们的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.initialize_all_variables()
# 声明tf.train.Saver类用于保存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    # 将模型保存到/path/to/model/model.ckpt文件
    saver.save(sess, "/path/to/model/model.ckpt")

上面的代码持久化了一个简单的tensorflow模型的功能。在这段代码中,通过 saver.save 函数将Tensorflow模型保存到了 /path/to/model/model.ckpt 文件中。
tensorflow模型一般会保存在后缀为 .ckpt 的文件中,虽然这里只指定了一个文件路径,但是仔细到该目录下查看,会发现这里其实有3个文件:

checkpoint, model.ckpt, model.ckpt.meta

那么这三个文件各有什么意义呢?

model.ckpt: 保存了tensorflow程序中每一个变量的取值
model.ckpt.meta: 保存了tensorflow计算图的结构
checkpoint: 保存了一个目录下的所有模型列表文件

加载持久化模型

上面介绍了如何持久化模型,那么这里就讲一下如何加载一个持久化保存的模型。

import tensorflow as tf

# 使用和保存模型代码中一样的方式来声明变量
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

saver = tf.train.Saver()

with tf.Session() as sess:
    # 加载已经保存的模型 并通过已经保存的模型中变量的值来计算加法
    saver.restore(sess, "/path/to/model/model.ckpt")
    print(sess.run(result))

这段代码和保存模型的代码基本一致。在加载模型的程序中也是先定义了Tensorflow计算图上所有的运算,并声明了一个 tf.train.Saver 类。
两段代码唯一不同的是,在加载模型的代码没有运行变量的初始化过程,而是将变量的值从已经保存的模型中加载出来。
如果不希望重复定义图上的运算,也可以直接加载已经持久化的图。

import tensorflow as tf
# 直接加载持久化的图
saver = tf.train.import_meta_graph("/path/to/model/model.ckpt/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "/path/to/model.ckpt")
    # 通过张量的名称来获取张量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0")))

在上面给出的例子程序中,默认加载了tensorflow计算图上定义的全部变量。但有时可能只需要报出或加载部分变量。
比如,可能有一个之前训练好的五层神经网络模型,但现在想尝试一个六层的神经网络,那么可以将前面的五层神经网络模型中的参数直接加载到新的模型,而仅仅将最后一层神经网络重新训练。

为了保存或加载部分变量,在声明 tf.train.Saver 类时可以提供一个列表来指定需要保存或加载的变量
比如在加载模型中的代码中使用 saver = tf.train.Saver([v1]) 命令来构建 tf.train.Saver 类,那么只有变量v1会被加载进来。
如果运行修改后只加载了v1的代码会的到其他变量未初始化的错误:
tensorflow.python.framework.errors.FailedPreconditionError:Attempting to use uninitialized value v2
因为v2没有被加载,所以v2在运行初始化之前是没有值的。除了可以选取需要被加载的变量,tf.train.Saver类也支持在保存或者加载时给变量重命名。

下面给出个简单的例子:

# 这里申明的变量名称和已经保存的模型中变量的名称不同
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name='other-v1')
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name='other-v2')

# 如果直接使用tf.train.Saver()来加载模型会报变量找不到的错误。

# 使用一个字典来重命名变量就可以加载原来的模型了。这个字典指定了原来名称为v1的变量加载到变量v1中(名称为other-v1),
# 名称为v2的变量加载到v2中(名称为other-v2)

saver = tf.train.Saver({'v1': v1, 'v2': v2})

在这个程序中,对变量v1和v2的名称进行了修改。如果直接通过tf.train.Saver默认的构造函数来加载保存的模型,那么程序会报变量找不到的错误。
因为变量在保存和加载时的名称不一致(这里指变量的name属性)。为了解决这个问题,TensorFlow可以通过字典将模型保存时的变量名和需要加载的变量联系起来

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值