Tensorflow2.x 训练网络时的指标输出,以及模型结构图导出

训练指标输出

1. 使用TensorBoard
2. 使用History类

TensorBorad

TensorBoard的Scalars可以可视化这些指标

使用步骤:

记录训练中的指标,需要执行以下操作:

  1. 创建KerasTensorBoard回调
  2. 指定日志目录
  3. 将TensorBoard回调传递Keras的Model.fit()

回调函数:

tf.keras.callbacks.TensorBoard(
           log_dir='logs', histogram_freq=0, write_graph=True, write_images=False,
          update_freq='epoch', profile_batch=2, embeddings_freq=0,
    embeddings_metadata=None, **kwargs
)

参数

  1. log_dir:将要由TensorBoard解析的日志文件保存到的目录路径。
  2. histogram_freq:计算模型各层的激活度和权重直方图的频率。如果设置为0,将不计算直方图。必须为直方图可视化指定验证数据。
  3. write_graph:是否在TensorBoard中可视化图形。当write_graph设置为True时,日志文件可能会变得很大。
  4. write_images:是否编写模型权重以在TensorBoard中可视化为图像。
  5. update_freq:‘batch’或’epoch’或整数。使用时’batch’,每批之后将损失和指标写入TensorBoard。同样适用于’epoch’。如果使用整数,假设1000,回调将每1000批将指标和损失写入TensorBoard。请注意,过于频繁地向TensorBoard写入可能会减慢训练速度。
  6. profile_batch:分析批次以采样计算特征。默认情况下,它将配置第二批。将profile_batch = 0设置为禁用分析。必须在TensorFlow急切模式下运行。
  7. embeddings_freq:嵌入层可视化的频率(以历元计)。如果设置为0,则嵌入将不可见。
  8. embeddings_metadata:将层名称映射到文件名的字典,该嵌入层的元数据保存在该文件名中。查看 有关元数据文件格式的 详细信息。如果相同的元数据文件用于所有嵌入层,则可以传递字符串。

定义好回调函数后,在fit()函数中加入参数
如下:

 logdir = "logs/" + datetime.now().strftime("%Y%m%d-%H%M%S")
 tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
.....(省略掉的代码)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]
  )

然后在终端,使用 tensorboard --logdir log/,就会出现下面的信息:

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.1.0 at http://localhost:6007/ (Press CTRL+C to quit)

进入连接即可出现网页,就会显示图表。
模型结构图
在这里插入图片描述

使用History类

这种方式比较简单,

history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
             callbacks=[tensorboard_callback]

fit() 会返回一个History的类,它的History.history属性记录了训练时期(每个epoch),训练损失和准确率以及验证损失和验证准确率。
如下:

model.compile(optimizer=keras.optimizers.Adam(lr=0.01),
        loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=['accuracy']
)
history = model.fit(train_data,epochs=5,validation_data=test_data,validation_freq=1,
                    # callbacks=[tensorboard_callback]
                    )
plt.plot(history.history['accuracy'], label='accuracy')
plt.plot(history.history['loss'], label='loss')
plt.plot(history.history['val_loss'], label = 'val_loss')
plt.plot(history.history['val_accuracy'], label = 'val_accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
# plt.ylim([0.5, 1])
plt.legend(loc='lower right')
plt.show()

在这里插入图片描述

  • 1
    点赞
  • 24
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

FlyDremever

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值