Tensorflow2.x tf.keras.callbacks函数分析

本文介绍了Tensorflow2.x中tf.keras.callbacks模块的使用,特别是TensorBoard回调函数的详细参数说明和set_model()方法。通过自定义回调函数,帮助理解回调在训练过程中的作用和调用顺序。
摘要由CSDN通过智能技术生成

前言

你好! 博主在使用tensorflow做深度学习的研究时,发现这一块的内容水太深了,tensorflow目前最新版本是2.3.0,但是我尝试更新到tensorflow2.2.0的时候,遇到了很多问题,无奈又返回到2.1.0。博主一直在使用anconda虚拟环境,因为它可以很方便的管理python环境,可以做到与多个深度学习库共存。本次讨论的是tensorflow2.x版本中使用的 tf.keras.callbacks.xxx,最近从tensorflow1.1x换成了tensorflow2.x。

发现好多内容变化了,没了slim模块,对于之前网络架构使用该模块搭建的人来说,真的是个大坑。不过2.x版本中有引入的很多内容,很大程度上简化了编程的难度,使tensorflow入门的门槛降低了许多。

其中callbacks模块自带的几个函数用着真的很方便,如下:官网链接
          常用的函数已经加粗

class BaseLogger:累积指标的时期平均值的回调。

class CSVLogger:将纪元结果流式传输到csv文件的回调。

class Callback:用于建立新回调的抽象基类。(支持自定义callbacks)

class EarlyStopping:当监视的变量停止改善时,停止训练。

class History:将事件记录到History对象中的回调。

class LambdaCallback:用于即时创建简单,自定义回调的回调。

class LearningRateScheduler:学习率调度程序

class ModelCheckpoint:保存模型的时机。

class ProgbarLogger:将指标输出到标准输出的回调。

class ReduceLROnPlateau:当指标停止改善时,降低学习率。

class RemoteMonitor:用于将事件流传输到服务器的回调。

class TensorBoard:为TensorBoard启用可视化,(这个函数真的很强大)

class TerminateOnNaN:当遇到NaN丢失时回调将终止训练。

callbacks函数总览

为了方便研究callbacks函数的运行机制,下面将callbacks函数的基础架构拿出来,所有的callbacks函数的子类方法都是在这个基础上构建的。

@keras_export('keras.callbacks.Callback')
class Callback(object):
  """Abstract base class used to build new callbacks.

  Attributes:
      params: dict. Training parameters
          (eg. verbosity, batch size, number of epochs...).
      model: instance of `keras.models.Model`.
          Reference of the model being trained.
      validation_data: Deprecated. Do not use.

  The `logs` dictionary that callback methods
  take as argument will contain keys for quantities relevant to
  the current batch or epoch.

  Currently, the `.fit()` method of the `Model` class
  will include the following quantities in the `logs` that
  it passes to its callbacks:

      on_epoch_end: logs include `acc` and `loss`, and
          optionally include `val_loss`
          (if validation is enabled in `fit`), and `val_acc`
          (if validation and accuracy monitoring are enabled).
      on_batch_begin: logs include `size`,
          the number of samples in the current batch.
      on_batch_end: logs include `loss`, and optionally `acc`
          (if accuracy monitoring is enabled).
  """

  def __init__(self):
    self.validation_data = None
    self.model = None
    # Whether this Callback should only run on the chief worker in a
    # Multi-Worker setting.
    # TODO(omalleyt): Make this attr public once solution is stable.
    self._chief_worker_only = None

  def set_params(self, params):
    self.params = params

  def set_model(self, model):
    self.model = model

  @doc_controls.for_subclass_implementers
  def on_batch_begin(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_begin`."""

  @doc_controls.for_subclass_implementers
  def on_batch_end(self, batch, logs=None):
    """A backwards compatibility alias for `on_train_batch_end`."""

  @doc_controls.for_subclass_implementers
  def on_epoch_begin(self, epoch, logs=None):
    """Called at the start of an epoch.
    """

  @doc_controls.for_subclass_implementers
  def on_epoch_end(self, epoch, logs=None):
    """Called at the end of an epoch.
    """

  @doc_controls.for_subclass_implementers
  def on_train_batch_begin(self, batch, logs=None):
    """Called at the beginning of a training batch in `fit` methods.
    """
    # For backwards compatibility.
    self.on_batch_begin(batch, logs=logs)

  @doc_controls.for_subclass_implementers
  def on_train_batch_end
  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

两只蜡笔的小新

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值