以下代码基于 TensorFlow 1.12.0（TF 1.x 的 Graph/Session API）。
保存模型:
import tensorflow as tf
import numpy as np
## Save the model (TF 1.x graph/session mode).
import os

W = tf.Variable([[1, 2, 3], [3, 2, 1]], dtype=tf.float32, name="weights")
b = tf.Variable([[7, 7, 7]], dtype=tf.float32, name="biases")
# Saver() with no var_list covers all saveable variables, keyed by their
# names ("weights", "biases"); the restore script must reuse those names.
saver = tf.train.Saver()

# In TF 1.x, Saver.save() does not create missing parent directories (it
# raises ValueError instead), so make sure the output directory exists.
os.makedirs("model_save", exist_ok=True)

with tf.Session() as sess:
    # Important: variables must be initialized before they can be saved.
    sess.run(tf.global_variables_initializer())
    # save() returns the checkpoint path prefix it wrote to.
    save_path = saver.save(sess, "model_save/model.ckpt")
    print("save to path:", save_path)
save to path: model_save/model.ckpt
$ ll model_save/
total 16
-rw-r--r-- 1 root root 77 Jan 13 10:49 checkpoint
-rw-r--r-- 1 root root 36 Jan 13 10:49 model.ckpt.data-00000-of-00001
-rw-r--r-- 1 root root 161 Jan 13 10:49 model.ckpt.index
-rw-r--r-- 1 root root 3440 Jan 13 10:49 model.ckpt.meta
导入模型:
import tensorflow as tf
import numpy as np
## Restore the model.
# Saver.restore() loads the saved values into variables of the *current*
# graph, matching them by name. Variables with the same names ("weights",
# "biases") and shapes must therefore be created first. Their initial values
# here are placeholders that restore() overwrites. (The model.ckpt.meta file
# also stores the graph; tf.train.import_meta_graph can rebuild it instead.)
W = tf.Variable(np.arange(6).reshape((2, 3)), dtype=tf.float32, name="weights")
b = tf.Variable(np.arange(3).reshape((1, 3)), dtype=tf.float32, name="biases")
# No initializer run is needed: restore() assigns every variable the saver covers.
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, "model_save/model.ckpt")
    print("Weights:", sess.run(W))
    print("biases:", sess.run(b))
Weights: [[1. 2. 3.]
[3. 2. 1.]]
biases: [[7. 7. 7.]]