保存训练模型的时候不仅持久化了计算图结构,也持久化了变量的取值。
TensorFlow提供的tf.train.NewCheckpointReader
类来查看保存的变量的信息。
import tensorflow as tf
v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')
result = v1 + v2
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(result.eval())
saver.save(sess,'./model/model.ckpt')
import tensorflow as tf
# tf.train.NewCheckpointReader可以读取checkpoint文件中保存的所有变量
reader = tf.train.NewCheckpointReader('./model/model.ckpt')
# 获取所有变量
print(reader.debug_string().decode("utf-8"))
# 输出 v1 (DT_FLOAT) [1]
# v2 (DT_FLOAT) [1]
# 获取所有变量列表,是一个从变量名到变量维度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
# variable_name为变量名称,global_variables[variable_name]为变量维度
print(variable_name, global_variables[variable_name])
# 输出: v2 [1]
# v1 [1]
print(reader.get_tensor('v1')) # 输出: [1.]
print(reader.get_tensor('v2')) # 输出: [2.]
# 获取所有变量列表,是一个从变量名到变量数据类型的字典
global_variables1 = reader.get_variable_to_dtype_map()
for variable_name in global_variables1:
print(variable_name, global_variables1[variable_name])
# 输出: v2 <dtype: 'float32'>
# v1 <dtype: 'float32'>
print(reader.get_tensor('v1')) # 输出: [1.]
print(reader.get_tensor('v2')) # 输出: [2.]