tf.train.NewCheckpointReader

保存训练模型的时候不仅持久化了计算图结构,也持久化了变量的取值。
TensorFlow提供的tf.train.NewCheckpointReader类来查看保存的变量的信息。

import tensorflow as tf

v1 = tf.Variable(tf.constant(1.0,tf.float32, [1]),name='v1')
v2 = tf.Variable(tf.constant(2.0,tf.float32, [1]),name='v2')

result = v1 + v2
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(result.eval())
    saver.save(sess,'./model/model.ckpt')

在这里插入图片描述


import tensorflow as tf

# tf.train.NewCheckpointReader可以读取checkpoint文件中保存的所有变量
reader = tf.train.NewCheckpointReader('./model/model.ckpt')

# 获取所有变量
print(reader.debug_string().decode("utf-8"))
# 输出  v1 (DT_FLOAT) [1]
#      v2 (DT_FLOAT) [1]

# 获取所有变量列表,是一个从变量名到变量维度的字典
global_variables = reader.get_variable_to_shape_map()
for variable_name in global_variables:
    # variable_name为变量名称,global_variables[variable_name]为变量维度
    print(variable_name, global_variables[variable_name])
     # 输出: v2 [1]
     #      v1 [1]
print(reader.get_tensor('v1'))  # 输出: [1.]
print(reader.get_tensor('v2'))  # 输出: [2.]

# 获取所有变量列表,是一个从变量名到变量数据类型的字典
global_variables1 = reader.get_variable_to_dtype_map()
for variable_name in global_variables1:
    print(variable_name, global_variables1[variable_name])
     # 输出: v2 <dtype: 'float32'>
     #      v1 <dtype: 'float32'>
print(reader.get_tensor('v1'))  # 输出: [1.]
print(reader.get_tensor('v2'))  # 输出: [2.]
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值