early_stopping

early_stopping

逻辑

在训练模型时,先申请一个early_stopping类,patience 表明能容忍几次,delta表明能容忍在训练过程中val_loss 的上升的范围

每一个epoch都把val_loss和模型传入到该实例中,正常来说,随着训练的过程,val_loss应该跟train_loss一起变小,但过拟合时,train_loss 会降低,val_loss会升高,我们设置patience表示容忍几次升高就停止

源代码

外部调用
early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
for i in epoch:
	early_stopping(val_loss, self.model, path)
earlystopping类
class EarlyStopping:
    def __init__(self, patience=7, verbose=False, delta=0):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta

    def __call__(self, val_loss, model, path,ret,opt):
        score = -val_loss
        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path,ret,opt)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model, path,ret,opt)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, path,ret,opt):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
        np.save(path + '/' + 'ret.npy',ret)
        np.save(path + '/' + 'opt.npy',opt)
        self.val_loss_min = val_loss
  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值