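tf.train.NewCheckpointReader('path'): path is the path the checkpoint was saved to. The function returns a reader object from which all of the saved variables can be retrieved.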
For example:
First save a model with two variables, v and v1 (the code uses the TensorFlow 1.x API).
import tensorflow as tf

v = tf.Variable(0, dtype=tf.float32, name='v')
v1 = tf.Variable(0, dtype=tf.float32, name='v1')
result = v + v1
# A placeholder is not a variable, so it is not written to the checkpoint
x = tf.placeholder(tf.float32, shape=[1], name='x')
test = result + x
# tf.initialize_all_variables() is deprecated; global_variables_initializer() replaces it
init = tf.global_variables_initializer()
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(init)
    saver.save(sess, "/home/penglu/Desktop/lp/model.ckpt")
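The string passed to saver.save is a path prefix, not a single file: in the default (V2) checkpoint format TensorFlow writes model.ckpt.index and model.ckpt.data-00000-of-00001, plus model.ckpt.meta for the graph and a small checkpoint bookkeeping file. That same prefix is what NewCheckpointReader expects later. The sketch below checks this in a temporary directory. It assumes TensorFlow 2.x (or 1.14+) accessed through the tf.compat.v1 graph-mode API, and the temporary directory is a hypothetical stand-in for the path above.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 2.x (or 1.14+) using the tf.compat.v1 graph-mode API.
tf.compat.v1.disable_eager_execution()

# Hypothetical scratch directory standing in for /home/penglu/Desktop/lp
ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "model.ckpt")

v = tf.compat.v1.Variable(0, dtype=tf.float32, name="v")
v1 = tf.compat.v1.Variable(0, dtype=tf.float32, name="v1")
saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saved_path = saver.save(sess, prefix)  # returns the prefix it wrote

# No file is literally named "model.ckpt"; the prefix names a group of files.
files = sorted(os.listdir(ckpt_dir))
print(saved_path)
print(files)
```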
Use tf.train.NewCheckpointReader to list all saved variables:
import tensorflow as tf

reader = tf.train.NewCheckpointReader("/home/penglu/Desktop/lp/model.ckpt")
# Maps each saved variable name to its shape
variables = reader.get_variable_to_shape_map()
for ele in variables:
    print(ele)
Output:
v1
v
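get_variable_to_shape_map only returns names and shapes. To read the stored values, the same reader also provides get_tensor(name), and get_variable_to_dtype_map() gives each variable's dtype. The map is a plain dict, so its iteration order is not something to rely on (in the output above, v1 happens to come before v); the sketch therefore sorts the names. As before, it assumes TensorFlow 2.x (or 1.14+) through tf.compat.v1, uses a temporary directory in place of the original path, and picks nonzero initial values purely so the restored numbers are easy to tell apart.

```python
import os
import tempfile

import tensorflow as tf

# Assumption: TF 2.x (or 1.14+) using the tf.compat.v1 graph-mode API.
tf.compat.v1.disable_eager_execution()

ckpt_dir = tempfile.mkdtemp()  # hypothetical stand-in for the original directory
prefix = os.path.join(ckpt_dir, "model.ckpt")

# Illustrative nonzero values so the restored numbers are distinguishable.
v = tf.compat.v1.Variable(1.5, dtype=tf.float32, name="v")
v1 = tf.compat.v1.Variable(2.5, dtype=tf.float32, name="v1")
saver = tf.compat.v1.train.Saver()

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    saver.save(sess, prefix)

# Read the checkpoint back directly from disk, without a session.
reader = tf.compat.v1.train.NewCheckpointReader(prefix)
shapes = reader.get_variable_to_shape_map()  # name -> shape (list of ints)
dtypes = reader.get_variable_to_dtype_map()  # name -> tf.DType

values = {}
for name in sorted(shapes):  # sorted: dict order is not guaranteed
    values[name] = reader.get_tensor(name)  # value returned as a NumPy object
    print(name, shapes[name], dtypes[name].name, values[name])
```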