keras回调函数使用方法

keras可以使用回调函数来查看训练模型的内在状态,可以通过设置 callbacks 关键字参数的方式,传递一个列表的回调函数到  Model 类型的 .fit() 方法,然后模型在训练时,相应的回调函数会在各自的阶段被调用。

1. TensorBoard回调函数

keras.callbacks.TensorBoard(log_dir='./logs', histogram_freq=0, batch_size=32, write_graph=True, write_grads=False, write_images=False, embeddings_freq=0, embeddings_layer_names=None, embeddings_metadata=None, embeddings_data=None, update_freq='epoch')

TensorBoard 是 tensorflow 提供的一个可视化工具。该回调函数通过为 TensorBoard 编写一个日志, 以可视化测试和训练的标准评估的动态图像, 或可视化模型中不同层的激活值直方图。在安装了 tensorflow或tensorflow-gpu的前提下,可从命令行启动 TensorBoard : full_path_to_your_log是logs目录的绝对路径,不能是相对路径。

tensorboard --logdir=full_path_to_your_logs

参数含义:

  • log_dir: 用来保存被 TensorBoard 分析的日志文件的路径,可以使用相对路径。
  • histogram_freq: 对于模型中各个层计算激活值和模型权重直方图的频率(训练轮数中)。 如果设置成 0 ,直方图不会被计算。对于直方图可视化的验证数据(或分离数据)一定要明确指出。
  • write_graph: 是否在TensorBoard 中可视化图像。 如果 write_graph 被设置为 True,日志文件会变得非常大。
  • write_grads: 是否在 TensorBoard 中可视化梯度值直方图。 histogram_freq 必须大于 0 。
  • batch_size: 用作直方图计算的传入神经元网络输入批的大小。
  • write_images: 是否在 TensorBoard 中将模型权重以图片形式可视化。
  • embeddings_freq: 被选中的嵌入层会被保存的频率(在训练轮中)。
  • embeddings_layer_names: 一个列表,存放将被监测的网络层的名字。 如果是 None 或空列表,那么所有的嵌入层都会被监测。
  • embeddings_metadata: 一个字典,层的名字 对应 保存有这个嵌入层元数据文件的名字。 以防同样的元数据被用于所有的嵌入层,字符串可以被传入。
  • embeddings_data: 要嵌入在 embeddings_layer_names 指定的层的数据。 Numpy 数组(如果模型有单个输入)或 Numpy 数组列表(如果模型有多个输入)。
  • update_freq: 'batch' 或 'epoch' 或 整数。当使用 'batch' 时,在每个 batch 之后将损失和评估值写入到 tensorboard 中。当使用'epoch' 时类似。如果使用整数,例如 10000,这个回调会在每 10000 个样本之后将损失和评估值写入到 tensorboard 中。注意,频繁往tensorboard 写入数据会减缓训练速度。

2. ModelCheckpoint回调函数

keras.callbacks.ModelCheckpoint(filepath,monitor='val_loss',verbose=0,save_best_only=False, save_weights_only=False, mode='auto', period=1) 

参数含义:

  • filename:字符串,保存模型的路径。
  • monitor:需要监视的值。
  • verbose:信息展示模式,0或1(checkpoint的保存信息,类似Epoch 00001: saving model to ...)。
  • save_best_only:当设置为True时,监测值有改进时才会保存当前的模型( the latest best model according to the quantity monitored will not be overwritten)。
  • mode:‘auto’,‘min’,‘max’之一,在save_best_only=True时决定性能最佳模型的评判准则。例如,当监测值为val_acc时,模式应为max;当监测值为val_loss时,模式应为min;在auto模式下,评价准则由被监测值的名字自动推断。
  • save_weights_only:若设置为True,则只保存模型权重,否则将保存整个模型(包括模型结构,配置信息等)。
  • period:CheckPoint之间间隔的epoch数。

3. EarlyStopping函数

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

作用:

EarlySopping是用于提前停止训练的callbacks,可以实现当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。因为当训练集上的loss不在减小时,继续训练可能导致测试集上的准确率下降。 而继续训练导致测试准确率下降的原因可能是:(1)过拟合 (2)学习率过大导致不收敛 (3)使用正则项的时候,loss的减少可能不是因为准确率增加导致的,而是因为权重大小的降低。

参数含义:

  • monitor: 监控的数据接口,有’acc’,’val_acc’,’loss’,’val_loss’等等。正常情况下如果有验证集,就用’val_acc’或者’val_loss’。但是因为笔者用的是5折交叉验证,没有单设验证集,所以只能用’acc’了。
  • min_delta:增大或减小的阈值,只有大于这个部分才算作improvement。这个值的大小取决于monitor,也反映了你的容忍程度。例如笔者的monitor是’acc’,同时其变化范围在70%-90%之间,所以对于小于0.01%的变化不关心。加上观察到训练过程中存在抖动的情况(即先下降后上升),所以适当增大容忍程度,最终设为0.003%。
  • patience:能够容忍多少个epoch内都没有improvement。这个设置其实是在抖动和真正的准确率下降之间做tradeoff。如果patience设的大,那么最终得到的准确率要略低于模型可以达到的最高准确率。如果patience设的小,那么模型很可能在前期抖动,还在全图搜索的阶段就停止了,准确率一般很差。patience的大小和learning rate直接相关。在learning rate设定的情况下,前期先训练几次观察抖动的epoch number,比其稍大些设置patience。在learning rate变化的情况下,建议要略小于最大的抖动epoch number。笔者在引入EarlyStopping之前就已经得到可以接受的结果了,EarlyStopping算是锦上添花,所以patience设的比较高,设为抖动epoch number的最大值。
  • mode: 就’auto’, ‘min’, ‘,max’三个可能。如果知道是要上升还是下降,建议设置一下。笔者的monitor是’acc’,所以mode=’max’。
  • min_delta和patience都和“避免模型停止在抖动过程中”有关系,所以调节的时候需要互相协调。通常情况下,min_delta降低,那么patience可以适当减少;min_delta增加,那么patience需要适当延长;反之亦然。

 

参考:

https://keras.io/zh/callbacks/

https://blog.csdn.net/breeze5428/article/details/80875323

https://blog.csdn.net/silent56_th/article/details/72845912

  • 0
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值