前言
在辛辛苦苦跑了几个小时甚至几天之后,你训练出了几十万个或者更多的参数,那么你肯定不想只使用这些参数仅仅一次,那么就涉及到这些参数的保存以及提取,幸运的是,tensorflow已经帮我们集成好了相关函数,就是接下来要介绍的tf.train.Saver() 类。
tf.train.Saver()
一 . 用于保存权重和偏重(参数)
在使用之前要先实例化一个类,例如以下代码:
saver = tf.train.Saver()
如何保存?
import tensorflow as tf
import numpy as np
# Save to file
#remember to define the same dtype and shape when restore
W = tf.Variable([[1,2,3],[3,4,5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1,2,3]], dtype=tf.float32, name='biases')
# init= tf.initialize_all_variables() # tf 马上就要废弃这种写法
# 替换成下面的写法:
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess, "my_net/save_net.ckpt")
#
print("Save to path: ", save_path)
这里Saver()类有一个save方法,其参数为(会话名称,要保存的文件路径以及具体文件)保存后的结果如下:
这里会自动生成一个“checkpoint”文件以及其他几个.ckpt文件,用来存储参数。
保存了以后如何提取,或者说读取参数?
见以下代码:
w = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# 这里不需要初始化步骤 init= tf.initialize_all_variables()
saver = tf.train.Saver()
with tf.Session() as sess:
# 提取变量
saver.restore(sess, "my_net/save_net.ckpt")
print("weights:", sess.run(w))
print("biases:", sess.run(b))
这里Saver() 提供了一个restore方法,其参数为(会话,需要提取的文件)
这里有几点需要说明一下:
- Saver() 类只能存储和提取神经网络的参数,现在还不能存储整个网络架构,这个比较操蛋(不过我相信以后肯定会出现类似的存储整个训练好的架构的函数),现如今想要使用已经训练好的参数,还是需要重新定义一个一模一样的参数变量,无论是在数据类型上,还是shape上