早停止(Early Stopping)-PyTorch版本(代码使用教程)

本文介绍了如何在PyTorch中使用早停法来解决过拟合问题,通过监控验证loss并设置耐心阈值,当验证loss不再降低时自动停止训练。作者提供了GitHub链接和代码示例,展示了早停法的实施步骤和效果。
摘要由CSDN通过智能技术生成

一、说明

记录自己使用早停法来解决过拟合问题的经历。

这里给出的是pytorch版本,需要tensorflow版本的,可以使用chatgpt转换,也可以自己转换。

二、原理

这个早停法的原理就是,对比你每次的验证loss,如果超过20epoch(可以自己设置)都没有上升,就认为过拟合了,就会跳出循环,停止代码。

ex:假如你跑到43个epoch,验证loss=0.02,等到跑到63个epoch时,验证loss始终没有低于0.01,就会停止代码

三、代码实现

首先提供原github链接,有兴趣的朋友可以看一下这个。

https://github.com/Bjarten/early-stopping-pytorch

如果不能进github下载的话,下面是我下载好的链接。

蓝奏云:https://wwqg.lanzouj.com/igmg71d0dicd

1.我先介绍一下这个参数

  • patience:上次模型在验证集上损失降低之后等待的epoch,此处设置为20。
  • verbose:默认为False,是否显示loss具体下降了多少。
  • counter:用于统计等待了多少个epoch,大于你设置的patience之后,就会停止循环。
  • best_score:记录模型评估的最好loss,用于比较。
  • early_step:默认为False,如果为True,模型就停止循环。
  • val_loss_min:默认为正无穷(np.Inf), 模型评估损失函数的最小值。
  • delta:默认为0, 表示模型损失函数改进的最小值,当超过这个值时候表示模型有所改进。

2.现在讲一下怎么使用。

1.首先,将github链接中的“pytorchtools.py”文件加入到你的项目中。

2.然后这样。

在开头引用

这一步其实可以不用,直接在调用的时候给参数。

不过,如果你代码中有这个参数设置的话,就写一行这个代码,如果没有的话,就不用写了。

注:如果你代码中有很多个要给参数的地方,这样写每次改一下就可以,否则改一下参数,要改很多个地方。

这个调用不要放到每次epoch的循环里,放到外面。如果没有刚才那个参数设置,这里args.patience就替换成你要设置参数。

下面这段代码放到epoch循环里面的最后位置就可以了

注:for循环就是类似这种:

for epoch in range(args.epoch):

意思就是说,每次循环都会检查一下是否符合早停的条件,如果符合就跳出循环,否则就继续循环。

以上就是早停法的使用流程了,希望可以帮到你。

四、效果图

  • 8
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 5
    评论
PyTorch EarlyStopping 是一个用于在训练过程中提前停止模型训练的技术。当模型在训练过程中出现过拟合或者性能不再提升时,EarlyStopping 可以帮助我们停止训练,以避免过拟合并节省时间和计算资源。 在 PyTorch 中,我们可以通过自定义一个 EarlyStopping 类来实现这个功能。以下是一个简单的示例代码: ```python import numpy as np import torch class EarlyStopping: def __init__(self, patience=5, delta=0): self.patience = patience self.delta = delta self.best_loss = np.Inf self.counter = 0 self.early_stop = False def __call__(self, val_loss): if val_loss < self.best_loss - self.delta: self.best_loss = val_loss self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.early_stop = True return self.early_stop ``` 在训练过程中,我们可以使用 EarlyStopping 类来监测验证集的损失值,并在满足停止条件时停止训练。例如: ```python # 创建 EarlyStopping 实例 early_stopping = EarlyStopping(patience=3) for epoch in range(num_epochs): # 训练模型 # 在验证集上计算损失值 val_loss = calculate_validation_loss(model, validation_data) # 检查是否满足停止条件 if early_stopping(val_loss): print("Early stopping") break # 继续训练 ``` 在上述示例中,`patience` 参数表示允许验证集损失连续 `patience` 个 epoch 没有下降的次数,`delta` 参数表示损失值必须至少下降 `delta` 才会被认为是有明显改进。如果连续 `patience` 次都没有达到这个改进,训练将被停止。 这就是 PyTorch EarlyStopping 的基本用法,它可以帮助我们更加高效地训练模型,并避免过拟合。
评论 5
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

WinterWanderer

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值