This example uses TensorFlow's Saver (tf.train.Saver) to save a model and then load it back:
Saving the model:
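import os
import tensorflow as tf  # TensorFlow 1.x API; in 2.x these calls live under tf.compat.v1
# The name= arguments become the keys under which the values are stored in the checkpoint
W = tf.Variable([[1, 2, 3], [3, 4, 5]], dtype=tf.float32, name='weights')
b = tf.Variable([[1, 2, 3]], dtype=tf.float32, name='biases')
# tf.initialize_all_variables() is deprecated; global_variables_initializer() replaces it
init = tf.global_variables_initializer()
saver = tf.train.Saver()
os.makedirs("Model", exist_ok=True)  # create the checkpoint directory first; Saver.save may refuse a missing one
with tf.Session() as sess:
    sess.run(init)
    save_path = saver.save(sess, "Model/model.ckpt")  # returns the path that was written
    print("Model saved to:", save_path)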
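Conceptually, the checkpoint written above maps each variable's name ("weights", "biases") to its current value. The toy sketch below illustrates only that idea in pure Python. It assumes a JSON file in place of TensorFlow's real binary checkpoint files, and `save_variables` is a hypothetical helper, not a TensorFlow API.

```python
import json
import os
import tempfile


def save_variables(variables, path):
    """Toy stand-in for saver.save: write {name: value} to a JSON file.

    Hypothetical helper for illustration only; TensorFlow's real
    checkpoints are binary (.index / .data files), not JSON.
    """
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "w") as f:
        json.dump(variables, f)
    return path  # like saver.save, hand back the path that was written


# Values keyed by the same names passed via name='weights' / name='biases'
variables = {
    "weights": [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]],
    "biases": [[1.0, 2.0, 3.0]],
}
ckpt_path = os.path.join(tempfile.mkdtemp(), "Model", "model.json")
saved = save_variables(variables, ckpt_path)

with open(saved) as f:
    print(sorted(json.load(f)))  # ['biases', 'weights']
```

Only the names and values matter for restoring later, which is why the loading script must reuse the same name= arguments.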
Loading the model:
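import numpy as np
import tensorflow as tf  # TensorFlow 1.x API; in 2.x these calls live under tf.compat.v1
# Redefine the variables with the SAME names (and shapes/dtypes) as when saving;
# the np.arange values are only placeholders that restore() overwrites
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# No initializer needed: restore() assigns the saved values to the variables
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "Model/model.ckpt")
    print("weights:", sess.run(W))
    print("biases:", sess.run(b))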
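The reading script skips the initializer because restore() assigns every variable the value stored under its name, so the np.arange values never survive. The toy sketch below mimics that name-based matching over plain dicts. `restore_variables` is a hypothetical helper, not TensorFlow's API; where TensorFlow would raise a NotFoundError for a name missing from the checkpoint, this sketch raises a KeyError.

```python
def restore_variables(graph_vars, checkpoint):
    """Toy stand-in for saver.restore: overwrite each variable by name.

    Hypothetical helper for illustration only. A KeyError plays the role
    of TensorFlow's NotFoundError when a name is absent from the checkpoint.
    """
    for name in graph_vars:
        if name not in checkpoint:
            raise KeyError(f"Key {name} not found in checkpoint")
        graph_vars[name] = checkpoint[name]
    return graph_vars


# Contents written by the saving script
checkpoint = {
    "weights": [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]],
    "biases": [[1.0, 2.0, 3.0]],
}
# Initial values from the reading script (the np.arange placeholders)
graph_vars = {
    "weights": [[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]],
    "biases": [[0.0, 1.0, 2.0]],
}
restored = restore_variables(graph_vars, checkpoint)
print("weights:", restored["weights"])  # weights: [[1.0, 2.0, 3.0], [3.0, 4.0, 5.0]]
print("biases:", restored["biases"])    # biases: [[1.0, 2.0, 3.0]]

# A variable whose name is not in the checkpoint cannot be restored
try:
    restore_variables({"w": [[0.0]]}, checkpoint)
except KeyError as err:
    print("restore failed:", err)
```

The design point this mirrors is that matching happens by name, not by Python variable name or creation order, so renaming `name="weights"` in the reading script would break the restore.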
The output is as follows (the printed values are the saved ones, not the np.arange placeholders):