一、什么是回调函数
回调函数是神经网络训练的重要组成部分
回调函数是在特定事件发生时被调用的函数。在机器学习中,回调函数通常与模型的训练过程相关联,并在训练的不同阶段触发执行特定的操作。回调函数提供了一种可以自定义和控制模型训练过程的方式,可以根据需要在训练过程中添加额外的功能或逻辑。
回调操作可以在训练的各个阶段执行,可能是在epoch之间,在处理一个batch之后,甚至在满足某个条件的情况下。回调可以利用许多创造性的方法来改进训练和性能,节省计算资源,并提供有关神经网络内部发生的事情的结论。
二、常见的回调函数
1、早停法(EarlyStopping)
早停法是一种常用的回调函数,用于在训练过程中监控模型的验证集损失(或其他指标),并在损失不再改善时停止训练,以防止过拟合,可以非常有助于防止在训练模型时产生额外的冗余运行。冗余运行会导致高昂的计算成本。当网络在给定的时间段内没有得到改善时,网络完成训练并停止使用计算资源。早停法根据设置的条件来判断是否停止训练。常见的条件包括:
monitor
:需要监控的指标,例如验证集的损失('val_loss')或准确率('val_accuracy')。patience
:当指标在多个训练周期内没有改善时,停止训练。即连续patience
个周期中,指标没有改善,则停止训练。mode
:指标的改善判断模式,可以是最小化('min')或最大化('max')。例如,如果monitor='val_loss'
且mode='min'
,则当验证集损失不再减小时,停止训练。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import EarlyStopping
# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 加载数据
# x_train 和 y_train 是训练数据的特征和标签
# 定义早停法回调函数
early_stopping = EarlyStopping(monitor='val_loss', patience=5, mode='min')
# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[early_stopping])
在上述代码中,我们首先创建了一个序列模型,然后通过添加全连接层构建了一个简单的神经网络模型。接下来,我们使用
compile
方法来编译模型,指定了优化器、损失函数和评估指标。在训练之前,我们定义了早停法回调函数
EarlyStopping
。在这个例子中,我们设置了监控指标为验证集损失(monitor='val_loss'
),当验证集损失在连续5个周期内没有改善时,训练将会提前停止(patience=5
)。这里的mode='min'
表示我们希望最小化验证集损失。最后,我们在训练模型时传入了回调函数列表
callbacks=[early_stopping]
,使得早停法回调函数在每个训练周期结束时被调用。这样,当满足早停法的条件时,训练将会自动停止。
2、学习率衰减(ReduceLROnPlateau)
学习率衰减是一种回调函数,用于在训练过程中动态地调整学习率。学习率衰减的目的是在训练的早期使用较大的学习率以快速收敛,而在训练的后期逐渐降低学习率,使模型更加稳定并有更好的收敛性能。学习率衰减的主要参数包括:
monitor
:需要监控的指标,通常是验证集的损失('val_loss')或验证集的准确率('val_accuracy')。factor
:学习率每次衰减的因子,通常在0和1之间。patience
:当指标在多个训练周期内没有改善时,减小学习率。即连续patience
个周期中,指标没有改善,则减小学习率。min_lr
:学习率的下限,当学习率衰减到该值以下时,停止衰减。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ReduceLROnPlateau
# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 加载数据
# x_train 和 y_train 是训练数据的特征和标签
# 定义学习率衰减回调函数
lr_decay = ReduceLROnPlateau(monitor='val_loss', factor=0.8, patience=3, min_lr=0.00001)
# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[lr_decay])
在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用
compile
方法来编译模型,指定了优化器、损失函数和评估指标。接下来,我们定义了一个学习率衰减回调函数
ReduceLROnPlateau
。在这个例子中,我们设置了监控指标为验证集损失(monitor='val_loss'
),当验证集损失在连续3个周期内没有改善时,学习率将会以指定的因子(factor=0.8
)进行衰减。衰减过程会在每次触发时发生,直到学习率达到最小值(min_lr=0.00001
)。最后,我们在训练模型时传入了回调函数列表
callbacks=[lr_decay]
,使得学习率衰减回调函数在每个训练周期结束时被调用,并相应地更新学习率。
3、NaN终止(TerminateOnNaN)
TerminateOnNaN
是一个Keras回调函数,用于在训练过程中检测并终止训练过程中出现的Na(Not a Number)值。
在深度学习中,NaN值通常表示数值计算溢出或发散的情况。当模型出现NaN值时,它表明模型的参数或梯度可能已经变得不稳定,无法继续有效地进行训练。为了避免继续训练出现无效的模型,可以使用TerminateOnNaN
回调函数来监测并终止训练。
TerminateOnNaN有助于防止在训练中产生梯度爆炸问题,因为输入NaN会导致网络的其他部分发生爆炸。如果不采用TerminateOnNaN,Keras并不阻止网络的训练。另外,nan会导致对计算能力的需求增加。为了防止这些情况发生,添加TerminateOnNaN是一个很好的安全检查。
当使用TerminateOnNaN
回调函数时,在每个训练周期结束后,回调函数将检查模型的损失值是否为NaN。如果损失值为NaN,则训练过程将被立即终止,并引发一个异常来停止训练。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import TerminateOnNaN
# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 加载数据
# x_train 和 y_train 是训练数据的特征和标签
# 定义TerminateOnNaN回调函数
nan_termination = TerminateOnNaN()
# 训练模型
model.fit(x_train, y_train, epochs=100, batch_size=32, callbacks=[nan_termination])
在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用
compile
方法来编译模型,指定了优化器、损失函数和评估指标。接下来,我们定义了一个
TerminateOnNaN
回调函数nan_termination
。最后,我们在训练模型时传入了回调函数列表
callbacks=[nan_termination]
,使得TerminateOnNaN
回调函数在每个训练周期结束时被调用。如果训练过程中出现NaN值,训练将会被立即终止,并引发一个异常。通过使用
TerminateOnNaN
回调函数,可以帮助我们识别并及时处理训练过程中的数值不稳定问题,确保训练过程的有效性和稳定性。
4、模型权重保存(ModelCheckpoint)
ModelCheckpoint
是一个Keras回调函数,用于在训练过程中保存模型的权重。
在深度学习中,模型的训练过程通常需要进行多个训练周期(epochs),而每个周期结束后的模型状态可能是不同的。为了保存在训练过程中得到的最佳模型或定期保存模型的权重,可以使用ModelCheckpoint
回调函数。
ModelCheckpoint
回调函数允许您定义保存模型权重的条件和设置。以下是一些常用的参数:
filepath
:表示模型权重保存的文件路径。可以使用格式化字符串指定文件名,例如weights.{epoch:02d}-{val_loss:.2f}.h5
,其中{epoch:02d}
表示当前训练周期的数字,{val_loss:.2f}
表示验证集损失的浮点数值。monitor
:表示要监测的指标,如val_loss
或val_accuracy
。当这个指标在训练过程中有改善时,模型的权重将被保存。save_best_only
:一个布尔值,表示是否仅保存在验证集上性能最好的模型权重。如果设置为True
,则只有当监测的指标有改善时才会保存模型权重。save_weights_only
:一个布尔值,表示是否仅保存模型的权重而不保存模型的结构。如果设置为True
,则只保存权重,否则将保存整个模型(包括结构和权重)。save_freq
:表示保存模型权重的频率。可以设置为整数表示每隔多少个样本保存一次,或者设置为字符串'epoch'
表示每个训练周期结束后保存一次。
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.callbacks import ModelCheckpoint
# 创建模型
model = Sequential()
model.add(Dense(64, activation='relu', input_dim=10))
model.add(Dense(64, activation='relu'))
model.add(Dense(1, activation='sigmoid'))
# 编译模型
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
# 加载数据
# x_train 和 y_train 是训练数据的特征和标签
# 定义ModelCheckpoint回调函数
checkpoint = ModelCheckpoint(filepath='weights.{epoch:02d}-{val_loss:.2f}.h5',
monitor='val_loss',
save_best_only=True,
save_weights_only=False,
save_freq='epoch')
# 训练模型
model.fit(x_train, y_train, validation_split=0.2, epochs=100, batch_size=32, callbacks=[checkpoint])
在上述代码中,我们创建了一个序列模型,并通过添加全连接层构建了一个简单的神经网络模型。然后,我们使用
compile
方法来编译模型,指定了优化器、损失函数和评估指标。接下来,我们定义了一个
ModelCheckpoint
回调函数checkpoint
。我们指定了保存模型权重的文件路径(filepath
),监测指标为验证集损失(monitor='val_loss'
),仅保存在验证集上性能最好的模型权重(save_best_only=True
),保存整个模型而不仅仅是权重(save_weights_only=False
),并在每个训练周期结束后保存模型权重(save_freq='epoch'
)。最后,我们在训练模型时传入了回调函数列表
callbacks=[checkpoint]
,使得ModelCheckpoint
回调函数在每个训练周期结束时被调用,并根据设定的条件保存模型权重。通过使用
ModelCheckpoint
回调函数,可以方便地保存模型的权重,并在训练过程中选择性地保存最佳模型或定期保存模型的权重,以备后续使用。
三、总结
通过使用回调函数,可以在模型训练过程中实现更多的灵活性和自定义功能。例如,可以根据验证集的损失来选择最佳的模型参数,或者在训练过程中动态调整学习率以提高模型的收敛速度和性能。回调函数使得训练过程更加可控和高效,并提供了一种灵活的方式来处理各种训练中的需求和场景。
但是编写自定义回调是Keras包含的最好的特性之一,它允许执行高度特定的操作。但是,请注意,构造它比使用默认回调要复杂得多。
我们的自定义回调将采用类的形式。类似于在PyTorch中构建神经网络,我们可以继承keras.callbacks.Callback回调,它是一个基类。
我们的类可以有许多函数,这些函数必须具有下面列出的给定名称以及这些函数将在何时运行。例如,将在每个epoch开始时运行on_epoch_begin函数。下面是Keras将从自定义回调中读取的所有函数,但是可以添加其他“helper”函数。
class CustomCallback(keras.callbacks.Callback): #继承keras的基类
def on_train_begin(self, logs=None):
#日志是某些度量的字典,例如键可以是 ['loss', 'mean_absolute_error']
def on_train_end(self, logs=None): ...
def on_epoch_begin(self, epoch, logs=None): ...
def on_epoch_end(self, epoch, logs=None): ...
def on_test_begin(self, logs=None): ...
def on_test_end(self, logs=None): ...
def on_predict_begin(self, logs=None): ...
def on_predict_end(self, logs=None): ...
def on_train_batch_begin(self, batch, logs=None): ...
def on_train_batch_end(self, batch, logs=None): ...
def on_test_batch_begin(self, batch, logs=None): ...
def on_test_batch_end(self, batch, logs=None): ...
def on_predict_batch_begin(self, batch, logs=None): ...
def on_predict_batch_end(self, batch, logs=None): ...
根据函数的不同,你可以访问不同的变量。例如,在函数on_epoch_begin中,该函数既可以访问epoch编号,也可以访问当前度量、日志的字典。如果需要其他信息,比如学习率,可以使用keras.backend.get_value.
然后,可以像对待其他回调函数一样对待你自定义的回调函数。
model.fit(X_train, y_train, epochs=15, callbacks=[CustomCallback()])
自定义回调的一些常见想法:
- 在JSON或CSV文件中记录训练结果。
- 每10个epoch就通过电子邮件发送训练结果。
- 在决定何时保存模型权重或者添加更复杂的功能。
- 训练一个简单的机器学习模型(例如使用sklearn),通过将其设置为类变量并以(x: action, y: change)的形式获取数据,来学习何时提高或降低学习率。
当在神经网络中使用回调函数时,你的控制力增强,神经网络变得更容易拟合。
参考:人工智能 - 神经网络训练中回调函数的实用教程 - 个人文章 - SegmentFault 思否
本篇文章仅作为笔记使用,与君共勉