early_stopping
逻辑
在训练模型时,先申请一个early_stopping类,patience 表明能容忍几次,delta表明能容忍在训练过程中val_loss 的上升的范围
每一个epoch都把val_loss和模型传入到该实例中,正常来说,随着训练的过程,val_loss应该跟train_loss一起变小,但过拟合时,train_loss 会降低,val_loss会升高,我们设置patience表示容忍几次升高就停止
源代码
外部调用
early_stopping = EarlyStopping(patience=self.args.patience, verbose=True)
for i in epoch:
early_stopping(val_loss, self.model, path)
earlystopping类
class EarlyStopping:
def __init__(self, patience=7, verbose=False, delta=0):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self, val_loss, model, path,ret,opt):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model, path,ret,opt)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model, path,ret,opt)
self.counter = 0
def save_checkpoint(self, val_loss, model, path,ret,opt):
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
np.save(path + '/' + 'ret.npy',ret)
np.save(path + '/' + 'opt.npy',opt)
self.val_loss_min = val_loss