Early Stopping 早停法原理与实现

Early Stopping

训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合

当模型在训练集上表现很好,在验证集上表现很差的时候,我们认为模型出现了过拟合的情况,early stoppping 就是用来预防过拟合的一种方法,简单且有效

原理

early stoppping 的原理是:当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题

缺点

如下图,模型在验证集上的表现可能咱短暂的变差之后有可能继续变好,并不是在验证集上的表现一旦变差就不会变好。early stoppping 主要是训练时间和泛化错误之间的权衡。
在这里插入图片描述

pytorch 实现

EarlyStopping 是用于提前停止训练的callbacks。具体地,可以达到当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。

初始化

  • patience:自上次模型在验证集上损失降低之后等待的时间,此处设置为7
  • verbose:当为False时,运行的时候将不显示详细信息
  • counter:计数器,当其值超过patience时候,使用early stopping
  • best_score:记录模型评估的最好分数
  • early_step:决定模型要不要early stop,为True则停
  • val_loss_min:模型评估损失函数的最小值,默认为正无穷(np.Inf)
  • delta:表示模型损失函数改进的最小值,当超过这个值时候表示模型有所改进
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

保存模型

传入参数:val_loss、model 和 path

verbose 为True,则打印详细信息
函数作用:在 path 路径下,保存当前 model,并更新 val_loss_min 为当前 val_loss

    def save_checkpoint(self, val_loss, model, path):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        self.val_loss_min = val_loss

调用

定义__call__()方法,该方法的功能类似于在类中重载 () 运算符,使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用

  1. 初始化时,设定了self.best_score = None

  2. if 语句第一行判断 self.best_score 是否为初始值,如果是初始值,则将 score 赋值给 self.best_score ,然后调用save_checkpoint()函数保存

  3. 当目前分数比最好分数加 self.delta 小时,就认为模型没有改进,将 counter 计数器加1,当计数器值超过 patience 的时候,就令early_stop为True,让模型停止训练

  4. 当目前分数比最好分数加 self.delta 大时,我们认为模型有改进,将目前分数赋值给最好分数,并将模型保存,令计数器归零。

    def __call__(self, val_loss, model, path):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

总体代码

总体代码如下:

class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model, path):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        self.val_loss_min = val_loss

  • 1
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值