TensorFlow之保存读取
一.保存
直接上代码:
import tensorflow as tf
import numpy as np
#Save to file
# remember to define the same dtype and shape wher restore
W = tf.Variable([[1,2,3],[3,4,5]],dtype = tf.float32,name = 'weights')
b = tf.Variable([[1,2,3]],dtype = tf.float32,name='biases')
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
save_path = saver.save(sess,"my_net/save_net.ckpt")
print("Save to path:",save_path)
注意要先在当前目录里创建文件夹.
运行结果:
二.恢复
################################################
# restore variables
# redefine the same shape and same type for your variables
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# not need init step
#目前来说只能保存Variable,不能保存整个神经网络
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "my_net/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))
运行结果:
成功读取上面保存的值.