Keras 回调函数
回调函数的作用
回调函数作用:在模型的训练过程中,可以用来监控并且能够干预模型的训练。 这里有一个形象的比喻:
仅仅使用model.fit()或model.fit_generator()开启训练有点像往天空扔一架纸飞机,你用力抛以后,就无法控制它的飞行轨迹,在哪里降落。我们这里利用回调函数,就有点是我们在飞一架无人机,我们可以获得无人机的数据,同时我们也可以操控无人机的飞行轨迹。
回调函数作用于模型
我们刚刚将没有使用的回调函数的model.fit()
比喻成不受控制的纸飞机,那么我们想要模型在训练的过程中受控制,我们需要怎么来控制,这里我们可以想到的是如下的点: - 中断训练 - 保存模型 - 加载不同的权重组 - 改变模型的状态 - ...
这些都是我们想向模型训练时施加的控制,那么接下来我们要问的问题是,我何时需要中断模型,何时需要保存模型以及keras中有那些函数来帮助我实现这些控制。 - Model chekpoint 模型检查点:在每个训练期之后保存模型 - Early Stopping 提前终止:当被监测的数量不再提升,则停止训练 - 动态调整训练的参数:LearningRateScheduler 学习率的优化 - 在训练过程中记录训练指标:CSVLogger
这里keras中提高较多的函数来操控,主要的应用的函数如下: - keras.callbacks.ModelCheckpoint - keras.callbacks.EarlyStopping - keras.callbacks.LearningRateScheduler - keras.callbacks.ReduceLROnPlateau - keras.callbacks.CSVLogger - ...
具体更多的函数参考keras文档
现在我们知道这么多的函数那么如何来使用他们,以及什么时候使用这些函数才合适呢?
回调函数是一个函数的合集,会在训练的阶段中所使用。你可以使用回调函数来查看训练模型的内在状态和统计。
这里先说明一个通用的过程,就是在fit
函数中,将callback函数作为一个对象参数传入其中,这样我们在训练的时候,就会根据不同的时间点,模型会去调用相应的函数。
具体的我们接着往下看。
ModelCheckpoint 与 EarlyStopping回调函数
对于EarlyStopping回调函数,最好的使用场景就是,如果我们发现经过了数轮后,目标指标不再有改善了,就可以提前终止,这样就节省时间。
该函数的具体参数如下:
keras.callbacks.EarlyStopping( monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
- monitor: 被监测的数据。
- min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
- patience: 没有进步的训练轮数,在这之后训练就会被停止。
- verbose: 详细信息模式。
- mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
- baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
- restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。
在中断的时候,我们可以结合ModelCheckpoint来保存模型,这样我们可以保证只保存的是最佳模型。
keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)
- filepath: 字符串,保存模型的路径。
- monitor: 被监测的数据。
- verbose: 详细信息模式,0 或者 1 。
- save_best_only: 如果 save_best_only=True, 被监测数据的最佳模型就不会被覆盖。
- mode: {auto, min, max} 的其中之一。 如果 save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
- save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
- period: 每个检查点之间的间隔(训练轮数)。
具体的coding如下:
import keras
callbacks_list = [
# 目标指标不再有改善了,就可以提前终止
keras.callbacks.EarlyStopping(
monitor='acc', # 被监测的模型的精度
patience=1 # 没有进步的训练轮数为1,在这之后训练就会被停止
),
# 保存模型
keras.callbacks.ModelCheckpoint(
filepath = 'my_model.h5', # 文件路径
monitor='val_loss', # 如果val_loss 没有改善就不覆盖
save_best_only=True) # 保持最佳模型
]
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['acc'])
model.fit(x,y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val,y_val))
ReduceLROnPlateau 回调函数
当标准评估停止提升时,降低学习速率。这样做的目的是通过降低或提高学习率来跳出局部最小值。这样保证模型继续训练下去,达到最优解。
具体参数如下
keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)
- monitor: 被监测的数据。
- factor: 学习速率被降低的因数。新的学习速率 = 学习速率 * 因数
- patience: 没有进步的训练轮数,在这之后训练速率会被降低。
- verbose: 整数。0:安静,1:更新信息。
- mode: {auto, min, max} 其中之一。如果是 min 模式,学习速率会被降低如果被监测的数据已经停止下降; 在 max 模式,学习塑料会被降低如果被监测的数据已经停止上升; 在 auto 模式,方向会被从被监测的数据中自动推断出来。
- min_delta: 对于测量新的最优化的阀值,只关注巨大的改变。
- cooldown: 在学习速率被降低之后,重新恢复正常操作之前等待的训练轮数量。
- min_lr: 学习速率的下边界。
# 定义回调函数集合
callbacks_list = [
#
keras.callbacks.ReduceLROnPlateau(
monitor='val_loss' # 监控模型的验证损失
factor=0.1, # 触发时将学习率除以 10
patience=10, #如果验证损失在 10 轮内都没有改善,那么就触发这个回调函数
)
]
model.fit(x, y,
epochs=10,
batch_size=32,
callbacks=callbacks_list,
validation_data=(x_val, y_val))
分享关于人工智能,机器学习,深度学习以及计算机视觉的好文章,同时自己对于这个领域学习心得笔记。想要一起深入学习人工智能的小伙伴一起结伴学习吧!扫码上车!