原文链接: TensorFlow 获取保存模型的参数值
上一篇: TensorFlow 线性回归 拟合
tf的所有操作必须使用run才能生效
所以只有W的值被改变了,b的值依然是0
保存模型
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# 模型参数
W = tf.Variable(tf.random_normal([1]), name="weight")
b = tf.Variable(tf.zeros([1]), name="bias")
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(init)
# 所有操作必须使用run才能生效
sess.run(tf.assign(W, tf.constant(10., shape=[1])))
tf.assign(b, tf.constant(10., shape=[1]))
saver.save(sess, 'net/my_net.ckpt')
模型参数值的获取
from tensorflow.python.tools.inspect_checkpoint import print_tensors_in_checkpoint_file
save_dir = 'net/'
print_tensors_in_checkpoint_file(save_dir + 'my_net.ckpt', None, True)
tensor_name: bias
[0.]
tensor_name: weight
[10.]