【Pytorch-从一团乱麻到入门】:6、Pytorch 选择最终模型的方式:save best model & early stop

在模型训练时一般会进行多轮,那么到底哪一轮训练出来的模型是最优的呢?如果在脚本中挑选出最合适的模型呢?

针对上述问题,一般会有如下几种解决方法;

1、最占用存储但是却是最稳妥的方法:每一轮的模型都保存,模型保存方式为:

torch.save(model, "model.pkl")

2、早停机制,即在训练时保存效果在一定范围内不再提升时的模型。

早停机制是一种正则化的手段,用于避免训练数据集上的过拟合。早期停止会跟踪验证损失(val_loss),如果损失连续几个 epoch 停止下降,训练就会停止。

pytorch提供了实现早停机制的相遇脚本:pytorchtool.py ,下载路径为:

GitHub - Bjarten/early-stopping-pytorch: Early stopping for PyTorch

其中的 EarlyStopping 类用于创建一个对象,以便在训练 PyTorch 模型时跟踪验证损失。每次验证丢失减少时,它都会保存模型的一个检查点,在EarlyStopping类中设置了patience参数,即在最后一次验证损失改善后,我们希望在中断训练循环之前等待多少个epochs,在等待了patience个epoch后,如果模型效果不下降,那么这次模型则被保存为best-model.

具体使用方式为:

下载earlystop 脚本,然后将脚本放在模型训练脚本同一路径下,具体使用方式为:

# import EarlyStopping

from pytorchtools import EarlyStopping


在模型训练中:

early_stopping = EarlyStopping(patience=20, verbose=True) ###20次都不下降则为best model

early_stopping(val_loss, model)

if early_stopping.early_stop:
            print("The Early stopping epoch is this:",epoch)
            #stop_epoch=epoch
            break
#####保留last  checkpoint with the best model
model.load_state_dict(torch.load('checkpoint.pt'))

程序会自动搜索loss不再下降的点,持续训练patience次后不再下降即停止并保存模型。

此部分参考如下:

原文链接:pytorchtools的使用-早停机制(EarlyStopping)_peacefairy的博客-CSDN博客_pytorch 早停

3、保存损失函数不再下降的模型/保存测试集准确性最高的模型/保存测试集AUC最高的模型

此方法实现起来较为简单,根据实际情况在模型训练时记录每次产生的loss/acc/AUC值,每次进行一次判断,例如:若loss变低,则保留此次模型,若loss值变高,则保持上次保存的模型。

备注:此方法容易选出极端模型结果,使用时需要小心。

具体使用方法为:

best_acc=0.0

#####复制模型的参数
best_model_wts=copy.deepcopy(model.state_dict())

###拷贝模型最高精度下的参数
if val_acc_all[-1]>best_acc:
    best_acc=val_acc_all[-1]
    best_model_wts=copy.deepcopy(model.state_dict())
    time_use=time.time()-since
    print("Train and val complete in {:.0f}m {:.0f}s".format(time_use//60,time_use%60))
###使用最好模型的参数
model.load_state_dict(best_model_wts)

  • 5
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 4
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值