tensorflow tf.train.Saver()在网络模型参数保存以及提取时的用法

前言

在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。

tf.train.Saver()

一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:

saver = tf.train.Saver()

如何保存?

import tensorflow as tf
import numpy as np

# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')

# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()

saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "my_net/save_net.ckpt")
    #
    print("Save to path: ", save_path)

这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
在这里插入图片描述
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。

保存了以后如何提取,或者说读取参数?
见以下代码:

w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")

# 这里不需要初始化步骤 init= tf.initialize_all_variables()

saver = tf.train.Saver()
with tf.Session() as sess:
    # 提取变量
    saver.restore(sess, "my_net/save_net.ckpt")
    print("weights:", sess.run(w))
    print("biases:", sess.run(b))

这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)

这里有几点需要说明一下:

  1. Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上
  • 1
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值