在TensorFlow中保存已经训练好的神经网络模型

通常,训练一个具有一定实用价值的深度神经网络是非常消耗计算时间的。所以在使用时,最好的方法是导入已经训练好的模型,重用它,而不是每次都重新训练。


如果要在TensorFlow中保存已经训练好的神经网络模型,所需的核心方法就是Saver.save,它位于Saver类中:

  1. Saver类提供了向checkpoints文件保存和从checkpoints文件中恢复变量的相关方法。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。
  2. 只要提供一个计数器,当计数器触发时,Saver类可以自动的生成checkpoint文件。这让我们可以在训练过程中保存多个中间结果。例如,我们可以保存每一步训练的结果。
  3. 为了避免填满整个磁盘,Saver可以自动的管理Checkpoints文件。例如,我们可以指定保存最近的N个Checkpoints文件。

将训练好的模型参数保存起来,以便以后进行验证或测试,这是我们经常要做的事情。模型保存,先要创建一个Saver对象,如:

saver=tf.train.Saver()
在创建这个Saver对象的时候,有一个参数我们经常会用到,就是 max_to_keep 参数,这个是用来设置保存模型的个数,默认为5,即 max_to_keep=5,保存最近的5个模型。如果你想每训练一代(epoch)就想保存一次模型,则可以将 max_to_keep设置为None或者0,如:
saver=tf.train.Saver(max_to_keep=0)

但是这样做除了多占用硬盘,并没有实际多大的用处,因此不推荐。

当然,如果你只想保存最后一代的模型,则只需要将max_to_keep设置为1即可,即

saver=tf.train.Saver(max_to_keep=1)
创建完saver对象后,就可以保存训练好的模型了,如:
saver.save(sess,'folder_for_nn/save_net.ckpt',global_step=step)

第一个参数sess,这个就不用说了。第二个参数设定保存的路径和名字,第三个参数将训练的次数作为后缀加入到模型名字中。

saver.save(sess, 'my-model', global_step=0) ==>      filename: 'my-model-0'
...
saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'

来看一段简单的示例代码:

import tensorflow as tf
import numpy as np

## Save to file
# remember to define the same dtype and shape when restore
W = tf.Variable([[2,5,7],[11,13,19]], dtype=tf.float32, name='weights')
b = tf.Variable([[23,29,31]], dtype=tf.float32, name='biases')

# initialization
init = tf.global_variables_initializer()

saver = tf.train.Saver()
特别地,需要指明你想要存储的位置:
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "folder_for_nn/save_net.ckpt")
    print("Save to path: ", save_path)
假设我们的程序文件和你要保存模型参数文件的文件夹在同一个路径下,如下图所示(你可以事先建好这个文件夹,或者让TF来自动创建它,至少在我的实验中发现TF可以自行创建这个路径 ):

当你执行完程序后,程序显示模型参数已经被存入指定文件夹中。

这时你会发现原本空的文件夹里已经有了内容,如下所示:

下面要做的事情是,在需要使用这个模型的另外一个程序中,直接读入模型。之前版本的TensorFlow只能存储用于描述模型的参数,而不能存储神经网络的结构特征(现在不知道是否已经有了这种功能)。但是,如果模型是由你训练出来的,其实只需要把训练模型时用到的网络结构直接照搬即可。毕竟,参数才是定义和描述神经网络模型的核心了。


执行下面的代码。注意其中用于导入模型参数的核心方法是Saver.restore。

import tensorflow as tf
import numpy as np

# 先建立 W, b 的容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

#Notice: init = tf.global_variables_initializer() is unnecessary
saver = tf.train.Saver()
with tf.Session() as sess:
    # 提取变量
    saver.restore(sess,tf.train.latest_checkpoint('folder_for_nn'))
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))

最开始我们的为模型参数变量灌入的值,在读取文件成功后,已经被替换成了之前存储的模型产生。如下面的执行结果所示。


参考文献及推荐阅读材料:

http://www.cnblogs.com/denny402/p/6940134.html


(本文完)

评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

白马负金羁

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值