Early Stopping
训练深度学习神经网络的时候通常希望能获得最好的泛化性能,可以更好地拟合数据。但是所有的标准深度学习神经网络结构如全连接多层感知机都很容易过拟合。
当模型在训练集上表现很好,在验证集上表现很差的时候,我们认为模型出现了过拟合的情况,early stoppping 就是用来预防过拟合的一种方法,简单且有效。
原理
early stoppping 的原理是:当模型在验证集上的表现开始下降的时候,停止训练,这样就能避免继续训练导致过拟合的问题
缺点
如下图,模型在验证集上的表现可能咱短暂的变差之后有可能继续变好,并不是在验证集上的表现一旦变差就不会变好。early stoppping 主要是训练时间和泛化错误之间的权衡。
pytorch 实现
EarlyStopping 是用于提前停止训练的callbacks。具体地,可以达到当训练集上的loss不在减小(即减小的程度小于某个阈值)的时候停止继续训练。
初始化
- patience:自上次模型在验证集上损失降低之后等待的时间,此处设置为7
- verbose:当为False时,运行的时候将不显示详细信息
- counter:计数器,当其值超过patience时候,使用early stopping
- best_score:记录模型评估的最好分数
- early_step:决定模型要不要early stop,为True则停
- val_loss_min:模型评估损失函数的最小值,默认为正无穷(np.Inf)
- delta:表示模型损失函数改进的最小值,当超过这个值时候表示模型有所改进
class EarlyStopping:
def __init__(self, patience=7, verbose=False, delta=0):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
保存模型
传入参数:val_loss、model 和 path
verbose 为True,则打印详细信息
函数作用:在 path 路径下,保存当前 model,并更新 val_loss_min 为当前 val_loss
def save_checkpoint(self, val_loss, model, path):
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
self.val_loss_min = val_loss
调用
定义__call__()方法,该方法的功能类似于在类中重载 () 运算符,使得类实例对象可以像调用普通函数那样,以“对象名()”的形式使用
-
初始化时,设定了
self.best_score = None
-
if 语句第一行判断 self.best_score 是否为初始值,如果是初始值,则将 score 赋值给 self.best_score ,然后调用save_checkpoint()函数保存
-
当目前分数比最好分数加 self.delta 小时,就认为模型没有改进,将 counter 计数器加1,当计数器值超过 patience 的时候,就令early_stop为True,让模型停止训练。
-
当目前分数比最好分数加 self.delta 大时,我们认为模型有改进,将目前分数赋值给最好分数,并将模型保存,令计数器归零。
def __call__(self, val_loss, model, path):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model, path)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model, path)
self.counter = 0
总体代码
总体代码如下:
class EarlyStopping:
def __init__(self, patience=7, verbose=False, delta=0):
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self, val_loss, model, path):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model, path)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model, path)
self.counter = 0
def save_checkpoint(self, val_loss, model, path):
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), path + '/' + 'checkpoint.pth')
self.val_loss_min = val_loss