查看TensorFlow已训模型的结构和网络参数

文章目录

概要

通过以下实例,你将学会如何查看神经网络结构并打印出训练参数。

流程

  • 准备一个简易的二分类数据集,并编写一个单层的神经网络
train_data = np.array([[1, 2, 3, 4, 5], 
                       [7, 7, 2, 4, 10], 
                       [1, 9, 3, 6, 5], 
                       [6, 7, 8, 9, 10]])

train_label = np.array([1, 0, 1, 0])  #标签与样本一一对齐


""" 定义一个单层的神经网络 """
model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(1, activation=None)
])
  • 编译,训练,并保存模型
model.compile(
    loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
    optimizer='adam'
)
model.fit(train_data,
          train_label,
          epochs=2750)

tf.saved_model.save(model, "model_dir")  #保存到当前目录中,目录名为model_dir
  • 模型保存形式

模型节点和矩阵参数集中保存在 .data-00000-of-00001和 .index文件中,利用这两个文件中创建CheckpointReader对象。

  • 利用模型的Checkpoint对象查看模型结构和参数

Checkpoint对象存储了模型中所有可tracable追踪的对象,并记录保存着这些对象的参数及名称。可通过 tf.train.load_checkpoint()方法获得一个CheckpointReader对象,该对象可以读取Checkpoint内的所有信息。

"""  最后的variables是.data-00000-of-00001和 .index文件去掉后缀后的表达形式,
     从而统一代表着这两个文件"""
save_path = './model_dir/variables/variables'  # 

reader = tf.train.load_checkpoint(save_path)  # 得到CheckpointReader

"""  打印Checkpoint中存储的所有参数名和参数shape """
for variable_name, variable_shape in reader.get_variable_to_shape_map().items():
    print(f'{variable_name} : {variable_shape}')
 

optimizer/_variables/2/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_iterations/.ATTRIBUTES/VARIABLE_VALUE : []
_CHECKPOINTABLE_OBJECT_GRAPH : []
keras_api/metrics/0/count/.ATTRIBUTES/VARIABLE_VALUE : []
keras_api/metrics/0/total/.ATTRIBUTES/VARIABLE_VALUE : []
layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_variables/1/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]
optimizer/_learning_rate/.ATTRIBUTES/VARIABLE_VALUE : []
optimizer/_variables/3/.ATTRIBUTES/VARIABLE_VALUE : [1]
optimizer/_variables/4/.ATTRIBUTES/VARIABLE_VALUE : [1]

其中Dense层的权重参数和偏差bias的显示信息为,

layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE : [1]
layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE : [5, 1]

接着利用刚刚打印出的参数名即可查看其参数值,

print(reader.get_tensor('layer_with_weights-0/kernel/.ATTRIBUTES/VARIABLE_VALUE'))
print(reader.get_tensor("layer_with_weights-0/bias/.ATTRIBUTES/VARIABLE_VALUE"))


[[-1.7741445 ]
 [-0.07314294]
 [-0.07213379]
 [ 1.1694099 ]
 [-0.36803177]]

[1.7487208]

  • 验证
model = tf.saved_model.load('model_dir')
print(model([[1, 2, 3, 4, 5]]))
output = -1.7741445 - 2*0.07314294 - 3*0.07213379 + 4*1.1694099 - 5*0.36803177+1.7487208
print(output)


tf.Tensor([[2.4493697]], shape=(1, 1), dtype=float32)

2.4493698000000004

  • 6
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值