【pytorch EarlyStopping】深度学习之早停法入门·相信我,一篇就够。

这个方法更好的解决了模型过拟合问题。

EarlyStopping的原理是提前结束训练轮次来达到“早停“的目的,故训练轮次需要设置的大一点以求更好的早停(比如可以设置100epoch)。

首先,我们需要一个一个标识,可以采用'val_acc’、’val_loss’等等,这些量在每一个轮次中都会不断更新自己的值,也和模型的参数息息相关,所以我们想通过他们间接操作模型参数。以val_loss来说,当模型训练时可能会出现当val_loss到一定值的时候会出现回弹的情况,所以我们希望在他回弹之前结束模型的训练。 

早停法其实一共有3类停止标准,这里我们选用最简单的一种入门。话不多说,上代码!!!

import numpy as np
import torch

导入两个最基本的包就行,因为早停法是一种可以自己就写出来的算法!!!

参数有5个:

第一个patience:这个是当有连续的patience个轮次数值没有继续下降,反而上升的时候结束训练的条件(以val_loss为例)

第二个verbose:这个其实就是是否print一些值,可也不传参,因为他有默认值

第三个delta:这个就是控制对比是的”标准线“

第四个path:这个是权重保存路径,早停法会在每一轮次次产生最优解(就是val_loss继续减少)的时候保存当前的模型参数。注:只要保存路径不变,每一次保存在文件里面的参数都会覆盖上一次保存在文件里面的参数。

第五个trace_func:这个就是显示每一个轮次变化的数值的方式,默认print,也可以改成进度条显示(tqdm的对象)

class EarlyStopping:
   
    def __init__(self, patience=7, verbose=False, delta=0, path='weight7-stop.pth', trace_func=print):
       
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path
        self.trace_func = trace_func

    def __call__(self, val_loss, model):

        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

重点就在中间那个__call__方法里面,比较的是这一轮的val_loss和之前最好的val_loss(可以加上一个数实现‘标准线’的‘上移’或者‘下移’)

实际应用与项目当中

这是我再积水检测项目中的代码的一部分。

我设置了patience为7.

epoch为200。(这个推荐小一点,因为太大没有意义,一定会过拟合的)

 注:本文使用的早停法源代码不是原创,取自github。

  • 8
    点赞
  • 49
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: PyTorch Early Stopping是一种用于训练深度学习模型的技术。当训练模型时,我们通常会设置一个固定的训练轮数,然后在每个训练轮结束后评估模型的性能。然而,有时模型在训练过程中会出现过拟合的情况,即模型在训练集上表现良好,但在未见过的数据上表现不佳。 为了解决过拟合问题,Early Stopping技术引入了一个称为“patience”的超参数。Patience表示在模型性能不再提升时需要等待的训练轮数。具体来说,当模型在超过patience个训练轮数后性能没有明显提升时,训练将被提前停止,从而避免了继续训练过拟合的模型。 实现Early Stopping的一种常见方法是使用验证集(validation set)。在每个训练轮结束后,将训练好的模型在验证集上进行评估,并记录模型的性能指标,如损失函数或准确率。如果模型的性能指标在连续的patience个训练轮中都没有明显提升,那么就说明模型已经达到了性能的极限,此时训练过程将被停止。 PyTorch提供了一些开源的工具库,如torch.optim和torchvision等,这些工具库中包含了Early Stopping的实现。在使用PyTorch进行深度学习模型训练时,我们可以通过在每个训练轮结束后进行模型性能评估,并比较连续的几次评估结果,来判断是否需要提前停止训练。 总而言之,PyTorch Early Stopping是一种用于避免模型过拟合的技术,通过在每个训练轮结束后评估模型性能,并设置适当的patience值,可以在训练过程中及时停止模型的训练,从而获得更好的泛化能力和性能。 ### 回答2: PyTorch是一个广泛使用的开源深度学习框架,而Early Stopping则是一种用于训练过程中自动停止模型训练的技术。PyTorch提供了一种方便的方法来实现Early StoppingEarly Stopping主要通过监控模型在验证集上的性能指标来判断是否停止训练。在训练过程中,可以在每个训练周期结束后对验证集进行评估,并根据评估结果来判断当前模型的性能是否有所提升。 通常情况下,可以设置一个patience参数,该参数表示如果在连续多个训练周期中性能指标没有提升,就认为模型已经停止改进,从而停止训练。在模型训练过程中,当连续多个周期中性能指标没有提升时,可以通过设置一个计数器来进行累计,并与patience进行比较。 当计数器达到patience时,可以选择在当前训练周期结束后停止训练,并保存最佳模型参数。这个最佳模型参数是根据验证集上的性能指标来确定的,通常是在训练过程中保存验证集上性能最好的模型参数。 通过使用Early Stopping技术,可以防止模型过拟合,并能更加高效地训练深度学习模型。PyTorch提供了一些库函数和回调函数,可以方便地在训练过程中实现Early Stopping。例如,可以使用`torchbearer`库来实现Early Stopping,并设置patience参数来控制训练的停止。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

活成自己的样子啊

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值