一 实例
将模型里的内容打印出来,同时演示将指定的内容打印出来。
二 代码
import tensorflow as tf
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
savedir = "log/"
print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)
W = tf.Variable(1.0, name="weight")
b = tf.Variable(2.0, name="bias")
# 放到一个字典里:
saver = tf.train.Saver({'weight': b, 'bias': W})
with tf.Session() as sess:
tf.global_variables_initializer().run()
saver.save(sess, savedir+"linermodel.cpkt")
print_tensors_in_checkpoint_file(savedir+"linermodel.cpkt", None, True)
三 运行结果
tensor_name: bias
[ 0.06552324]
tensor_name: weight
[ 2.04334879]
tensor_name: bias
1.0
tensor_name: weight
2.0
四 运行说明
可以看到,tensor_name:后面跟的就是创建的变量名,接着是它的数值。
tf.train.Saver函数里还可以放参数来实现更高级的功能,可以指定存储变量与变量的对应关系。
例子中,W的值设置为1.0,b的值设置为2.0。在创建saver时,将它们颠倒,保存的模型打印出来之后可以看到,bias变成了1.0,而weight变成了2.0。