Keras防止过拟合(三) 如何提前终止训练

解决过拟合的方法,前面已经讲了2种,Dropout层Keras防止过拟合(一)Dropout层源码细节,L1L2正则化keras防止过拟合(二) L1正则化与L2正则化源码细节和在自定义层中加入
除此之外,当损失函数不降反增,或是降低十分缓慢时,提前结束训练也是一个很好的办法。因为,随着模型训练次数的增多,模型会更加复杂,更易出现过拟合。本篇说明一下,keras如何提前终止训练。

使用回调函数callbacks中的EarlyStopping

简单的例子:

callback = keras.callbacks.EarlyStopping(monitor='loss', patience=1)#使用loss作为监测数据,轮数设置为1
model = Sequential()
model.add(Dense(10))
model.compile(loss='categorical_crossentropy', optimizer='sgd', loss='mse')
model.fit(x_train, y_train,epochs=10, batch_size=1, callbacks=[callback])

Keras使用fit函数对模型进行训练:

fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)

观察fit函数中的参数,其中有一项为callbacks,keras中文文档中介绍如下:
在这里插入图片描述
回调函数callbacks:
在这里插入图片描述
在这些回调函数中,有一项为EarlyStopping,用来提前终止训练:

keras.callbacks.EarlyStopping(monitor='val_loss', min_delta=0, patience=0, verbose=0, mode='auto', baseline=None, restore_best_weights=False)

在这里插入图片描述
其中,monitor参数指定被监测的数据(个人感觉这个翻译有问题,应该叫做被监测的参数)。默认为’val_loss’,验证集的损失值。可以根据自己的需求,指定被监测数据。min_delta参数指定最小变化,默认值为0,若变化小于min_delta,模型停止训练。
patience指定没有进步的轮数。
工作原理:若在patience轮数中,指定的被检测数据,提升的值小于min_delta,模型便停止训练。

关于mode参数的理解:有些被检测数据,例如loss,值越小越好,因此使用min模式。但有些被检测数据,例如Recall,Accuracy,值越大越好,因此使用max模式。关于auto模式,描述中是从被监测的数据的名字中自动判断是使用max还是min。但我观察源码后,发现其只能自动使Accuracy变为max,其余的都是min.因此,如果你指定的是Precision、Recall等,必须定义mode为max,否则,mode默认定义为auto,还是会认为值越小越好,因此会出现,第二轮就结束(patience为1的情况)。
Earlystopping部分源码:

if mode == 'min':
    self.monitor_op = np.less
elif mode == 'max':
    self.monitor_op = np.greater
else:
    if 'acc' in self.monitor:#可以自动识别acc,其余的都不行。
        self.monitor_op = np.greater
    else:
        self.monitor_op = np.less

关于monitor被监测数据,具体可以使用哪些,keras中文文档并没有说,打开keras官方文档,其中有更加全面的描述:EarlyStopping
在这里插入图片描述
文档中说明了,制定的被监测数据,必须是compile中定义过的评价参数(这也是为什么我觉得monitor应该称为被监测参数的原因),也就是说,只要是在compile中定义过的评价参数,都可以用来作为被监测数据。

可使用监测数据

使用keras时,定义好模型后,需要先配置训练模型,使用到的是compile函数:

compile(optimizer, loss=None, metrics=None, loss_weights=None, sample_weight_mode=None, weighted_metrics=None, target_tensors=None)

在这里插入图片描述
其中的metrics,用来指定评估模型用的参数。在这里定义过的,都可以作为EarlyStopping中的被监测数据。可使用的评价参数:
Metrics
需要注意的是,callback中monitor需要使用参数的字典名称,如损失为’loss’,准确率为’acc’。
若选择的被监测数据,没有在compile中定义,却在训练时指定,虽然不会终止训练,但是会给你提示:

model.compile(loss='categorical_crossentropy', optimizer='adam',
              metrics=[keras.metrics.Recall(top_k=10), keras.metrics.Precision(top_k=10)])
callback = keras.callbacks.EarlyStopping(monitor='auc', patience=1)
model.fit(x_train,y_train,epochs=3,batch_size=128,callbacks=[callback])

在这里插入图片描述

只在compile的metrics中加入recall,precision,没有加入auc,却使用auc作为提前终止训练的被监测数据。

需要注意的另一点,前面也说过,monitor默认为’val_loss’,是验证集的损失。但如果没有加入验证集呢?
在这里插描述

没有使用验证集,却使用了’val_loss’作为被监测数据

可以看到,虽然不会影响训练,但是由于你指定的被监测数据monitor不存在,所以没办法依此提前终止训练。
使用了验证集,才能指定验证集的评价参数作为被监测数据monitor。

  • 4
    点赞
  • 51
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值