Keras 回调 Callback学习总结

回调函数是一个函数的集合,会在训练阶段中使用。一般使用回调函数来查看训练模型的内在状态和统计信息。

使用方法:传递一个列表的回调函数名称作为callbacks关键字到Sequential或Model类型的…fit()方法。在训练时,相应的回调函数的方法就会在各自的阶段被调用。

编程分为两类:系统编程(system programming)和应用编程(application programming)。所谓系统编程,简单来说,就是编写库;而应用编程就是利用写好的各种库来编写具某种功用的程序,也就是应用。系统程序员会给自己写的库留下一些接口,即API(application programming interface,应用编程接口),以供应用程序员使用。所以在抽象层的图示里,库位于应用的底下。当程序跑起来时,一般情况下,应用程序(application program)会时常通过API调用库里所预先备好的函数。但是有些库函数(library function)却要求应用先传给它一个函数,好在合适的时候调用,以完成目标任务。这个被传入的、后又被调用的函数就称为回调函数(callback function)。

在这里插入图片描述

回调函数的应用实例

下面介绍几种常见的回调函数的用法

1. 利用回调函数动态调整学习速率

在模型的训练过程中,常常需要根据训练状态动态调整学习速率。当然本身可自动调整学习速率的优化器,如:Adagrad、Adadelta、RMSprop、Adam等除外。

具体的,Keras提供了两种调整学习速率的方法:

  1. LearningRateScheduler()方法,其将训练轮次epoch作为参数定义方法来调整学习速率。利用此方法可实现基于训练轮次epoch而逐渐减小学习速率的目的;
  2. ReduceLROnPlateau()方法,其根据评价指标自动设置学习速率。即当评价指标不再提升时,减少学习速率。

1.1 LearningRateScheduler()

keras.callbacks.callbacks.LearningRateScheduler(schedule, verbose=0)
  1. schedule: 一个函数,接受轮索引数作为输入(整数,从 0 开始迭代) 然后返回一个学习速率作为输出(浮点数)。
  2. verbose: 整数。 0:安静,1:更新信息。

示例

import keras.backend as K  # backend 作为获取设置学习速率的工具
from keras.callbacks import LearningRateScheduler

def scheduler(epoch):
    # 每隔100个epoch,学习速率减小为原来的1/10
    if epoch % 100 == 0 and epoch != 0:
        lr = K.get_value(model.optimizer.lr)
        K.set_value(model.optimizer.lr, lr * 0.1)
    return K.get_value(model.optimizer.lr)

reduce_lr = LearningRateScheduler(scheduler)

model.fit(train_x, train_y, batch_size=32, epochs=300, callbacks=[reduce_lr])

1.2 ReduceLROnPlateau()

当标准评估停止提升时,降低学习速率。

当学习停止时,模型总是会受益于降低 2-10 倍的学习速率。 这个回调函数监测一个数据并且当这个数据在一定「有耐心」的训练轮之后还没有进步, 那么学习速率就会被降低。

keras.callbacks.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=0, mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)
  1. monitor: 被监测的数据。
  2. factor: 学习速率被降低的因数。新的学习速率 = 学习速率 * 因数
  3. patience: 在监测质量经过多少轮次没有进度时即停止。如果验证频率 (model.fit(validation_freq=5)) 大于 1,则可能不会在每个轮次都产生验证数。
  4. verbose: 整数。0:安静,1:更新信息。
  5. mode: {auto, min, max} 其中之一。如果是 min 模式,学习速率会被降低如果被监测的数据已经停止下降; 在 max 模式,学习速率会被降低如果被监测的数据已经停止上升; 在 auto 模式,方向会被从被监测的数据中自动推断出来。
  6. min_delta: 对于测量新的最优化的阈值,只关注巨大的改变。
  7. cooldown: 在学习速率被降低之后,重新恢复正常操作之前等待的训练轮数量。
  8. min_lr: 学习速率的下边界。

实例

from keras.callbacks import ReduceLROnPlateau
redeuce_lr = ReduceLROnPlateau(monitor='val_loss', patience=10, mode='auto')
model.fit(train_x, train_y, batch_size=32, epochs=300, validation_split=0.1, callback=[reduce_lr])

2. 利用回调保存最佳模型

利用回调保存最佳模型需要用到ModelCheckpoint这个回调函数。其作用为:在每个训练期之后保存模型。

2.1 ModelCheckpoint()

keras.callbacks.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)

  1. filepath: 字符串,保存模型的路径。
  2. monitor: 被监测的数据。
  3. verbose: 详细信息模式,0 或者 1 。
  4. save_best_only: 如果 save_best_only=True, 被监测数据的最佳模型就不会被覆盖。
  5. mode: {auto, min, max} 的其中之一。 如果save_best_only=True,那么是否覆盖保存文件的决定就取决于被监测数据的最大或者最小值。 对于 val_acc,模式就会是 max,而对于 val_loss,模式就需要是 min,等等。 在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
  6. save_weights_only: 如果 True,那么只有模型的权重会被保存 (model.save_weights(filepath)), 否则的话,整个模型会被保存 (model.save(filepath))。
  7. period: 每个检查点之间的间隔(训练轮数)。

实例

from keras.callbacks import LearningRateScheduler, ModelCheckpoint

# 保存最佳模型
filepath = 'weights.{epoch:02d}-{val_loss:.2f}.hdf5'
checkpoint = ModelCheckpoint(filepath, monitor='val_accuracy', verbose=1, save_best_only=True, mode='max')
# epochs迭代周期,图片全部训练一次为一周期
history = model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, callbacks=[lr_new, checkpoint], validation_data=(x_test, y_test))

保存最佳模型后,可能想要提前结束训练,这是就必须用到callbacks的另一个回调函数EarlyStopping()

2.2 EarlyStopping()

当被监测的数量不再提升,则停止训练。

keras.callbacks.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)
  1. monitor: 被监测的数据。
  2. min_delta: 在被监测的数据中被认为是提升的最小变化, 例如,小于 min_delta 的绝对变化会被认为没有提升。
  3. patience: 在监测质量经过多少轮次没有进度时即停止。如果验证频率 (model.fit(validation_freq=5)) 大于 1,则可能不会在每个轮次都产生验证数。
  4. verbose: 详细信息模式。
  5. mode: {auto, min, max} 其中之一。 在 min 模式中, 当被监测的数据停止下降,训练就会停止;在 max 模式中,当被监测的数据停止上升,训练就会停止;在 auto 模式中,方向会自动从被监测的数据的名字中判断出来。
  6. baseline: 要监控的数量的基准值。 如果模型没有显示基准的改善,训练将停止。
  7. restore_best_weights: 是否从具有监测数量的最佳值的时期恢复模型权重。 如果为 False,则使用在训练的最后一步获得的模型权重。
  • 1
    点赞
  • 3
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值