Python深度学习之Keras回调函数与TensorBoard

Deep Learning with Python

这篇文章是我学习《Deep Learning with Python》(第二版,François Chollet 著) 时写的系列笔记之一。文章的内容是从 Jupyter notebooks 转成 Markdown 的,你可以去 GitHubGitee 找到原始的 .ipynb 笔记本。

你可以去这个网站在线阅读这本书的正版原文(英文)。这本书的作者也给出了配套的 Jupyter notebooks

本文为 第7章 高级的深度学习最佳实践 (Chapter 7. Advanced deep-learning best practices) 的笔记之一。

7.2 Inspecting and monitoring deep-learning models using Keras callbacks and TensorBoard

使用 Keras 回调函数和 TensorBoard 来检查并监控深度学习模型

用 model.fit() 开启一个复杂的训练任务后,我们就只能干等着,在结束前都不知道它有没有正确工作,也无法控制它,好似抛出了一架纸飞机,任它随风去往不确定的远方。比起这样不受控制的纸飞机,或许我们更希望要一台智能的无人机,可以感知环境,将数据发回给我们,并基于当前状态自主航行。 Keras 的回调函数与 TensorBoard 这样的工具就可以帮我们把“纸飞机”改造成“智能的无人机”。

训练中将回调函数作用于模型

我们在训练模型的时候,一开始是不知道要跑多少轮的,我们只能让它跑足够多的轮次,然后手动找出一个最佳的轮次数,重新用这个最佳轮次数去训练模型,这样相当耗时。所以,我们更希望当模型观测到验证损失不再改善时就自动停止训练。

这种操作就可以用 Keras 回调函数(callback)完成:Keras 提供了很多有用的 callback,放在 keras.callbacks 里,自动停止训练只是其中一种用法。

Callback 会在训练过程中的不同时间点被模型调用,它可以访问模型的状态,并可以采取一些行动,例如:

  • 模型检查点:在训练过程中的不同时间点保存模型的当前权重
  • 提前终止:验证损失不再改善时中断训练
  • 动态调节参数值:例如动态调整优化器的学习率
  • 记录训练指标和验证指标:用这些指标就可以将模型学到的表示可视化
使用 callback

Keras 内置了许多有用的 callback,例如:

  • ModelCheckpoint:在训练过程中保存训练到某些状态的模型。可以用来持续不断地保存模型,也可以选择性地保存目前的最佳模型;
  • EarlyStopping:监控的目标指标,如果在设定的轮数内不再改善,则中断训练;
  • ReduceLROnPlateau:在验证损失不再改善时(遇到loss plateau),降低学习率。

这些 callback 的使用也很简单:

from tensorflow import keras

callbacks_list = [
    # 在每轮完成后保存权重
    keras.callbacks.ModelCheckpoint(
        filepath='my_model.h5',  # 保存文件的路径
        monitor='val_loss',      # monitor:要验证的指标
        save_best_only=True,     # 只保存让 monitor 指标最好的模型(如果 monitor 没有改善,就不保存)
    ),
    # 不再改善时中断训练
    keras.callbacks.EarlyStopping(
        monitor='acc',           # 要验证的指标
        patience=10,             # 如果 monitor 在多于 patience 轮内(比如这里就是10+1=11轮)没有改善,则中断训练
    ),
    # 不再改善时降低学习率
    keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss',    # 要验证的指标
        factor=0.1,            # 触发时:学习率 *= factor
        patience=5,            # monitor 在 patience 轮内没有改善,则触发降低学习率
    ),
]

model.compile(optimizer='rmsprop', 
              loss='binary_crossentropy', 
              metrics=['acc'])    # 在 callback 里用到了 acc 做指标,所以这里的 metrics 里要有 acc

model.fit(x, y, 
          epochs=10, 
          batch_size=32, 
 
  • 2
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值