一、Saver保存
import tensorflow as tf
import numpy as np
#定义W和b
W = tf.Variable([[1,2,3],[3,5,6]],dtype = tf.float32,name = 'weight')
b = tf.Variable([1,2,3],dtype = tf.float32,name = 'biases')
#注:初始化变量Variable
init = tf.global_variables_initializer()
#建立tf.train.Saver() 来保存, 提取变量。
#建立my_net文件夹,保存变量
saver = tf.train.Saver()
sess = tf.Session()
sess.run(init)
#保存变量到路径my_net
save_path = saver.save(sess,"my_net/save_net.ckpt")#保存格式为ckpt
#输出保存的变量
print("save path:",save_path)
结果:
二、Saver读取
import tensorflow as tf
import numpy as np
#建立W,b的空容器
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
#不需要初始化变量
saver = tf.train.Saver()
with tf.Session() as sess:
saver.restore(sess, "my_net/save_net.ckpt")
print("weights:", sess.run(W))
print("biases:", sess.run(b))