Tensorflow2.x callbacks函数分析
前言
你好! 博主在使用tensorflow做深度学习的研究时,发现这一块的内容水太深了,tensorflow目前最新版本是2.3.0,但是我尝试更新到tensorflow2.2.0的时候,遇到了很多问题,无奈又返回到2.1.0。博主一直在使用anconda虚拟环境,因为它可以很方便的管理python环境,可以做到与多个深度学习库共存。本次讨论的是tensorflow2.x版本中使用的 tf.keras.callbacks.xxx,最近从tensorflow1.1x换成了tensorflow2.x。
发现好多内容变化了,没了slim模块,对于之前网络架构使用该模块搭建的人来说,真的是个大坑。不过2.x版本中有引入的很多内容,很大程度上简化了编程的难度,使tensorflow入门的门槛降低了许多。
其中callbacks模块自带的几个函数用着真的很方便,如下:官网链接
常用的函数已经加粗
class BaseLogger:累积指标的时期平均值的回调。
class CSVLogger:将纪元结果流式传输到csv文件的回调。
class Callback:用于建立新回调的抽象基类。(支持自定义callbacks)
class EarlyStopping:当监视的变量停止改善时,停止训练。
class History:将事件记录到History对象中的回调。
class LambdaCallback:用于即时创建简单,自定义回调的回调。
class LearningRateScheduler:学习率调度程序
class ModelCheckpoint:保存模型的时机。
class ProgbarLogger:将指标输出到标准输出的回调。
class ReduceLROnPlateau:当指标停止改善时,降低学习率。
class RemoteMonitor:用于将事件流传输到服务器的回调。
class TensorBoard:为TensorBoard启用可视化,(这个函数真的很强大)。
class TerminateOnNaN:当遇到NaN丢失时回调将终止训练。
callbacks函数总览
为了方便研究callbacks函数的运行机制,下面将callbacks函数的基础架构拿出来,所有的callbacks函数的子类方法都是在这个基础上构建的。
@keras_export('keras.callbacks.Callback')
class Callback(object):
"""Abstract base class used to build new callbacks.
Attributes:
params: dict. Training parameters
(eg. verbosity, batch size, number of epochs...).
model: instance of `keras.models.Model`.
Reference of the model being trained.
validation_data: Deprecated. Do not use.
The `logs` dictionary that callback methods
take as argument will contain keys for quantities relevant to
the current batch or epoch.
Currently, the `.fit()` method of the `Model` class
will include the following quantities in the `logs` that
it passes to its callbacks:
on_epoch_end: logs include `acc` and `loss`, and
optionally include `val_loss`
(if validation is enabled in `fit`), and `val_acc`
(if validation and accuracy monitoring are enabled).
on_batch_begin: logs include `size`,
the number of samples in the current batch.
on_batch_end: logs include `loss`, and optionally `acc`
(if accuracy monitoring is enabled).
"""
def __init__(self):
self.validation_data = None
self.model = None
# Whether this Callback should only run on the chief worker in a
# Multi-Worker setting.
# TODO(omalleyt): Make this attr public once solution is stable.
self._chief_worker_only = None
def set_params(self, params):
self.params = params
def set_model(self, model):
self.model = model
@doc_controls.for_subclass_implementers
def on_batch_begin(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_begin`."""
@doc_controls.for_subclass_implementers
def on_batch_end(self, batch, logs=None):
"""A backwards compatibility alias for `on_train_batch_end`."""
@doc_controls.for_subclass_implementers
def on_epoch_begin(self, epoch, logs=None):
"""Called at the start of an epoch.
"""
@doc_controls.for_subclass_implementers
def on_epoch_end(self, epoch, logs=None):
"""Called at the end of an epoch.
"""
@doc_controls.for_subclass_implementers
def on_train_batch_begin(self, batch, logs=None):
"""Called at the beginning of a training batch in `fit` methods.
"""
# For backwards compatibility.
self.on_batch_begin(batch, logs=logs)
@doc_controls.for_subclass_implementers
def on_train_batch_end