神经网络训练中回调函数的使用

一、什么是回调函数

回调函数是神经网络训练的重要组成部分

     回调函数是在特定事件发生时被调用的函数。在机器学习中,回调函数通常与模型的训练过程相关联,并在训练的不同阶段触发执行特定的操作。回调函数提供了一种可以自定义和控制模型训练过程的方式,可以根据需要在训练过程中添加额外的功能或逻辑。

     回调操作可以在训练的各个阶段执行,可能是在epoch之间,在处理一个batch之后,甚至在满足某个条件的情况下。回调可以利用许多创造性的方法来改进训练和性能,节省计算资源,并提供有关神经网络内部发生的事情的结论。

 二、常见的回调函数

1、早停法(EarlyStopping)

早停法是一种常用的回调函数,用于在训练过程中监控模型的验证集损失(或其他指标),并在损失不再改善时停止训练,以防止过拟合,可以非常有助于防止在训练模型时产生额外的冗余运行。冗余运行会导致高昂的计算成本。当网络在给定的时间段内没有得到改善时,网络完成训练并停止使用计算资源。早停法根据设置的条件来判断是否停止训练。常见的条件包括:

  • monitor:需要监控的指标,例如验证集的损失('val_loss')或准确率('val_accuracy')。
  • patience:当指标在多个训练周期内没有改善时,停止训练。即连续patience个周期中,指标没有改善,则停止训练。
  • mode:指标的改善判断模式,可以是最小化('min')或最大化('max')。例如,如果monitor='val_loss'mode='min',则当验证集损失不再减小时,停止训练。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping

# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 加载数据
# x_train 和 y_train 是训练数据的特征和标签

# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')

# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[early_stopping])

在上述代码中,我们首先创建了一个序列模型,然后通过添加全连接层构建了一个简单的神经网络模型。接下来,我们使用compile方法来编译模型,指定了优化器、损失函数和评估指标。

在训练之前,我们定义了早停法回调函数EarlyStopping。在这个例子中,我们设置了监控指标为验证集损失(monitor='val_loss'),当验证集损失在连续5个周期内没有改善时,训练将会提前停止(patience=5)。这里的mode='min'表示我们希望最小化验证集损失。

最后,我们在训练模型时传入了回调函数列表callbacks=[early_stopping],使得早停法回调函数在每个训练周期结束时被调用。这样,当满足早停法的条件时,训练将会自动停止。

 2、学习率衰减(ReduceLROnPlateau)

学习率衰减是一种回调函数,用于在训练过程中动态地调整学习率。学习率衰减的目的是在训练的早期使用较大的学习率以快速收敛,而在训练的后期逐渐降低学习率,使模型更加稳定并有更好的收敛性能。学习率衰减的主要参数包括:

  • monitor:需要监控的指标,通常是验证集的损失('val_loss')或验证集的准确率('val_accuracy')。
  • factor:学习率每次衰减的因子,通常在0和1之间。
  • patience:当指标在多个训练周期内没有改善时,减小学习率。即连续patience个周期中,指标没有改善,则减小学习率。
  • min_lr:学习率的下限,当学习率衰减到该值以下时,停止衰减。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ReduceLROnPlateau

# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 加载数据
# x_train 和 y_train 是训练数据的特征和标签

# 定义学习率衰减回调函数
lr_decay = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=3, min_lr=0.00001)

# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[lr_decay])

在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用compile方法来编译模型,指定了优化器、损失函数和评估指标。

接下来,我们定义了一个学习率衰减回调函数ReduceLROnPlateau。在这个例子中,我们设置了监控指标为验证集损失(monitor='val_loss'),当验证集损失在连续3个周期内没有改善时,学习率将会以指定的因子(factor=0.8)进行衰减。衰减过程会在每次触发时发生,直到学习率达到最小值(min_lr=0.00001)。

最后,我们在训练模型时传入了回调函数列表callbacks=[lr_decay],使得学习率衰减回调函数在每个训练周期结束时被调用,并相应地更新学习率。

 3、NaN终止(TerminateOnNaN)

TerminateOnNaN是一个Keras回调函数,用于在训练过程中检测并终止训练过程中出现的Na(Not a Number)值。

在深度学习中,NaN值通常表示数值计算溢出或发散的情况。当模型出现NaN值时,它表明模型的参数或梯度可能已经变得不稳定,无法继续有效地进行训练。为了避免继续训练出现无效的模型,可以使用TerminateOnNaN回调函数来监测并终止训练。

TerminateOnNaN有助于防止在训练中产生梯度爆炸问题,因为输入NaN会导致网络的其他部分发生爆炸。如果不采用TerminateOnNaN,Keras并不阻止网络的训练。另外,nan会导致对计算能力的需求增加。为了防止这些情况发生,添加TerminateOnNaN是一个很好的安全检查。

当使用TerminateOnNaN回调函数时,在每个训练周期结束后,回调函数将检查模型的损失值是否为NaN。如果损失值为NaN,则训练过程将被立即终止,并引发一个异常来停止训练。

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import TerminateOnNaN

# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 加载数据
# x_train 和 y_train 是训练数据的特征和标签

# 定义TerminateOnNaN回调函数
nan_termination = TerminateOnNaN()

# 训练模型
model.fit(x_train, y_train, epochs=100, batch_size=32, callbacks=[nan_termination])

在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用compile方法来编译模型,指定了优化器、损失函数和评估指标。

接下来,我们定义了一个TerminateOnNaN回调函数nan_termination

最后,我们在训练模型时传入了回调函数列表callbacks=[nan_termination],使得TerminateOnNaN回调函数在每个训练周期结束时被调用。如果训练过程中出现NaN值,训练将会被立即终止,并引发一个异常。

通过使用TerminateOnNaN回调函数,可以帮助我们识别并及时处理训练过程中的数值不稳定问题,确保训练过程的有效性和稳定性。

4、模型权重保存(ModelCheckpoint)

ModelCheckpoint是一个Keras回调函数,用于在训练过程中保存模型的权重。

在深度学习中,模型的训练过程通常需要进行多个训练周期(epochs),而每个周期结束后的模型状态可能是不同的。为了保存在训练过程中得到的最佳模型或定期保存模型的权重,可以使用ModelCheckpoint回调函数。

ModelCheckpoint回调函数允许您定义保存模型权重的条件和设置。以下是一些常用的参数:

  • filepath:表示模型权重保存的文件路径。可以使用格式化字符串指定文件名,例如weights.{epoch:02d}-{val_loss:.2f}.h5,其中{epoch:02d}表示当前训练周期的数字,{val_loss:.2f}表示验证集损失的浮点数值。
  • monitor:表示要监测的指标,如val_lossval_accuracy。当这个指标在训练过程中有改善时,模型的权重将被保存。
  • save_best_only:一个布尔值,表示是否仅保存在验证集上性能最好的模型权重。如果设置为True,则只有当监测的指标有改善时才会保存模型权重。
  • save_weights_only:一个布尔值,表示是否仅保存模型的权重而不保存模型的结构。如果设置为True,则只保存权重,否则将保存整个模型(包括结构和权重)。
  • save_freq:表示保存模型权重的频率。可以设置为整数表示每隔多少个样本保存一次,或者设置为字符串'epoch'表示每个训练周期结束后保存一次。

 

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint

# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# 加载数据
# x_train 和 y_train 是训练数据的特征和标签

# 定义ModelCheckpoint回调函数
checkpoint = ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.h5',
                             monitor='val_loss',
                             save_best_only=True,
                             save_weights_only=False,
                             save_freq='epoch')

# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[checkpoint])

在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用compile方法来编译模型,指定了优化器、损失函数和评估指标。

接下来,我们定义了一个ModelCheckpoint回调函数checkpoint。我们指定了保存模型权重的文件路径(filepath),监测指标为验证集损失(monitor='val_loss'),仅保存在验证集上性能最好的模型权重(save_best_only=True),保存整个模型而不仅仅是权重(save_weights_only=False),并在每个训练周期结束后保存模型权重(save_freq='epoch')。

最后,我们在训练模型时传入了回调函数列表callbacks=[checkpoint],使得ModelCheckpoint回调函数在每个训练周期结束时被调用,并根据设定的条件保存模型权重。

通过使用ModelCheckpoint回调函数,可以方便地保存模型的权重,并在训练过程中选择性地保存最佳模型或定期保存模型的权重,以备后续使用。

三、总结

通过使用回调函数,可以在模型训练过程中实现更多的灵活性和自定义功能。例如,可以根据验证集的损失来选择最佳的模型参数,或者在训练过程中动态调整学习率以提高模型的收敛速度和性能。回调函数使得训练过程更加可控和高效,并提供了一种灵活的方式来处理各种训练中的需求和场景。

但是编写自定义回调是Keras包含的最好的特性之一,它允许执行高度特定的操作。但是,请注意,构造它比使用默认回调要复杂得多。

我们的自定义回调将采用类的形式。类似于在PyTorch中构建神经网络,我们可以继承keras.callbacks.Callback回调,它是一个基类。

我们的类可以有许多函数,这些函数必须具有下面列出的给定名称以及这些函数将在何时运行。例如,将在每个epoch开始时运行on_epoch_begin函数。下面是Keras将从自定义回调中读取的所有函数,但是可以添加其他“helper”函数。

class CustomCallback(keras.callbacks.Callback): #继承keras的基类
    def on_train_begin(self, logs=None):
        #日志是某些度量的字典,例如键可以是 ['loss', 'mean_absolute_error']
    def on_train_end(self, logs=None): ...
    def on_epoch_begin(self, epoch, logs=None): ...
    def on_epoch_end(self, epoch, logs=None): ...
    def on_test_begin(self, logs=None): ...
    def on_test_end(self, logs=None): ...
    def on_predict_begin(self, logs=None): ...
    def on_predict_end(self, logs=None): ...
    def on_train_batch_begin(self, batch, logs=None): ...
    def on_train_batch_end(self, batch, logs=None): ...
    def on_test_batch_begin(self, batch, logs=None): ...
    def on_test_batch_end(self, batch, logs=None): ...
    def on_predict_batch_begin(self, batch, logs=None): ...
    def on_predict_batch_end(self, batch, logs=None): ...

根据函数的不同,你可以访问不同的变量。例如,在函数on_epoch_begin中,该函数既可以访问epoch编号,也可以访问当前度量、日志的字典。如果需要其他信息,比如学习率,可以使用keras.backend.get_value.

然后,可以像对待其他回调函数一样对待你自定义的回调函数。

model.fit(X_train, y_train, epochs=15, callbacks=[CustomCallback()])

自定义回调的一些常见想法:

  • 在JSON或CSV文件中记录训练结果。
  • 每10个epoch就通过电子邮件发送训练结果。
  • 在决定何时保存模型权重或者添加更复杂的功能。
  • 训练一个简单的机器学习模型(例如使用sklearn),通过将其设置为类变量并以(x: action, y: change)的形式获取数据,来学习何时提高或降低学习率。 

当在神经网络中使用回调函数时,你的控制力增强,神经网络变得更容易拟合。

 参考:人工智能 - 神经网络训练中回调函数的实用教程 - 个人文章 - SegmentFault 思否

本篇文章仅作为笔记使用,与君共勉

  • 6
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
### 回答1: EarlyStopping 回调函数是在训练神经网络模型时经常使用的一种回调函数。它可以帮助我们在训练过程监测模型的性能,并在模型性能不再提升时停止训练,从而避免过拟合。具体来说,EarlyStopping 回调函数会在每个 epoch 结束后计算验证集上的性能指标,例如准确率、损失等,并与之前的最佳性能指标进行比较。如果性能指标没有提升,则可以停止训练。 在 Keras ,可以通过在模型的 fit() 函数添加 EarlyStopping 回调函数来实现此功能。例如: ```python from keras.callbacks import EarlyStopping earlystop_callback = EarlyStopping(monitor='val_loss', patience=3) model.fit(x_train, y_train, validation_data=(x_val, y_val), epochs=10, callbacks=[earlystop_callback]) ``` 在这个例子,我们设置了 monitor 参数为 'val_loss',表示监测验证集上的损失函数。如果连续 3 个 epoch 验证集的损失函数没有提升,则停止训练。 通过使用 EarlyStopping 回调函数,我们可以有效地避免过拟合,并且可以在适当的时候停止训练,从而节省时间和计算资源。 ### 回答2: EarlyStopping回调函数是一种用于监控训练过程并自动停止训练的机制。它基于一定的条件判断模型的性能是否有进一步改善的可能性,如果没有,则提前终止训练,以防止过拟合,并节省训练时间和资源。 EarlyStopping回调函数通常使用验证集上的性能指标来判断模型的训练状态。在每个训练周期结束时,该回调函数会计算验证集上的指标,例如验证集上的损失函数值或准确率等。然后它与之前的最佳指标值进行比较,如果模型的性能有所提升,则更新最佳指标值并保存当前模型的权重。如果经过一定的训练周期,模型性能在验证集上没有提升,则可以判断模型已经达到了最优状态,此时可以停止训练使用最佳模型。 EarlyStopping回调函数有几个重要的参数可以设置。首先,可以设置一个监控指标,例如损失函数值或准确率等。其次,可以设置一个容忍度参数,在验证集上的性能没有改善的情况下,允许容忍一定的训练周期。最后,还可以设置一个参数来指定是否保存最佳模型的权重。 使用EarlyStopping回调函数可以帮助我们更好地控制模型的训练过程,避免过拟合和浪费资源。通过提前终止训练,我们可以节省时间和计算资源,并且得到一个在验证集上性能较好的模型。总之,EarlyStopping回调函数神经网络训练过程一种非常有用的机制,能够有效地提升训练效率和模型性能。 ### 回答3: EarlyStopping回调函数是一种用于在训练神经网络模型时提前停止训练的一种方法。这个回调函数通过监测模型的训练指标,例如验证集上的损失函数值或准确率,来判断模型是否已经达到了停止训练的条件。 在训练过程,如果模型的验证集上的性能在连续的一定轮数内没有改善,那么EarlyStopping回调函数会触发停止训练的操作。具体来说,这个函数会监测验证集上的损失函数值,如果连续若干个轮数内该损失函数值都没有显著下降,则判定模型已经达到了过拟合的程度,停止训练以防止模型的泛化性能继续下降。 当EarlyStopping回调函数触发停止训练操作时,可以通过设置参数来保存在训练过程获得的最好的模型参数。这样可以确保在训练结束后,可以使用具有最好性能的模型参数进行预测或测试。 使用EarlyStopping回调函数的好处是避免了过拟合,提高了模型的泛化能力。在训练迭代次数过多时,模型会过分拟合训练集数据,导致在验证集或测试集上的性能下降。而EarlyStopping回调函数通过监测验证集上的指标,实时判断模型是否过拟合,及时终止训练,可以有效提高模型的泛化能力。 总之,EarlyStopping回调函数是一种有效的训练技巧,用于在训练神经网络模型时提前停止训练,以防止过拟合的发生。通过监测验证集上的指标,它可以实时判断模型的性能,并在连续若干轮性能没有改善时停止训练。这样可以提高模型的泛化能力并减少过拟合的风险。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值