例如，用下面的代码构建一个只有一层全连接层的模型，并把它的变量保存为 checkpoint（路径前缀 models/test）：
import os

import numpy as np
import tensorflow as tf
# Random batch of 10 samples with 2 features each. np.random.random returns
# float64, so the dense layer's variables are created as float64 too (the
# restoring script must use a matching dtype).
inputs = np.random.random([10, 2])
print(inputs)

# One fully connected layer. In the default graph its variables are named
# "dense/kernel" (shape [2, 2]) and "dense/bias" (shape [2]); these are the
# names stored in the checkpoint.
output = tf.layers.dense(inputs, units=2)

# Saver.save raises ValueError when the parent directory of the save path
# does not exist, so create it up front.
os.makedirs("models", exist_ok=True)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(output))

    # Show every variable (name, shape, dtype) and its current value; these
    # are exactly what the Saver below writes out.
    for var in tf.global_variables():
        print(var)
        print(sess.run(var))

    saver = tf.train.Saver()
    # Writes checkpoint files with prefix "models/test" and updates the
    # "models/checkpoint" state file, which lets a reader that is given just
    # the "models" directory locate this checkpoint.
    saver.save(sess, "models/test")
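然后在一个新的计算图中，使用 tf.train.init_from_checkpoint 从 checkpoint 中恢复部分变量。assignment_map 的键是 checkpoint 中的变量名（如 dense/kernel），值是新图中对应的变量名（如 inference/dense/kernel）。没有出现在映射中的变量（这里是第二层的 inference/dense_1/kernel）仍按默认初始化器随机初始化。init_from_checkpoint 通过替换变量的初始化器来加载数值，因此必须在创建 tf.global_variables_initializer() 之前调用：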
import tensorflow as tf
import numpy as np
# The dtype must match the checkpoint: the saving script fed a float64 NumPy
# array into tf.layers.dense, so the stored kernel/bias are float64.
inputs = tf.placeholder(dtype=tf.float64, shape=[None, 2])
with tf.variable_scope("inference"):
    # Creates inference/dense/kernel [2, 2] and inference/dense/bias [2];
    # these are restored from the checkpoint below.
    output = tf.layers.dense(inputs, units=2)
    # Creates inference/dense_1/kernel [2, 1]; it has no counterpart in the
    # checkpoint and keeps its default (random) initializer.
    output = tf.layers.dense(output, 1, use_bias=False)

# Checkpoint variable name (key) -> variable in this graph (value).
assignment_map = {
    "dense/kernel": "inference/dense/kernel",
    "dense/bias": "inference/dense/bias",
}
# init_from_checkpoint overrides the initializers of the mapped variables, so
# it must run before tf.global_variables_initializer() builds its init op.
# Passing the "models" directory makes it use the latest checkpoint recorded
# in models/checkpoint.
tf.train.init_from_checkpoint("models", assignment_map)
init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Restored variables should print the values saved by the first script.
    for var in tf.global_variables():
        print(var)
        print(sess.run(var))