PyTorch Early Stopping: Principle and Code

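Early stopping is a trick for training deep learning models that you really need to know. Combined with cross-validation, it helps keep the model from overfitting.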

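Early stopping is a widely used method, and in many cases it works better than explicit regularization methods. The idea is to monitor the model's performance on a validation set during training and to stop training once that performance starts to get worse, which avoids the overfitting that further training would cause. The main steps are:
1. Split the original training data into a training set and a validation set.
2. Train only on the training set, and compute the model's error on the validation set after every epoch.
3. Stop training when a stopping criterion is met, for example: the weight updates fall below a threshold, the prediction error rate falls below a threshold, or a set number of iterations is reached.
4. Use the parameters saved at the last epoch where the validation result improved as the model's final parameters.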
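The four steps above can be sketched end to end on synthetic data. This is a minimal illustration rather than the code from this post: the dataset, the `nn.Linear` model, the 160/40 split, and the values of `patience` and `max_epochs` are all assumptions chosen to keep the example small.

```python
import copy

import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset, random_split

torch.manual_seed(0)

# Synthetic regression data (a hypothetical stand-in for a real dataset).
X = torch.randn(200, 5)
y = X @ torch.randn(5, 1) + 0.1 * torch.randn(200, 1)

# Step 1: split the original training data into training and validation sets.
train_set, valid_set = random_split(TensorDataset(X, y), [160, 40])
train_loader = DataLoader(train_set, batch_size=32, shuffle=True)
valid_loader = DataLoader(valid_set, batch_size=40)

model = nn.Linear(5, 1)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

patience, max_epochs = 5, 200  # illustrative values
best_loss, best_state, bad_epochs = float("inf"), None, 0
history = []

for epoch in range(max_epochs):
    # Step 2: train on the training set only ...
    model.train()
    for data, target in train_loader:
        optimizer.zero_grad()
        criterion(model(data), target).backward()
        optimizer.step()

    # ... and measure the validation loss once per epoch.
    model.eval()
    with torch.no_grad():
        batch_losses = [criterion(model(d), t).item() for d, t in valid_loader]
    valid_loss = sum(batch_losses) / len(batch_losses)
    history.append(valid_loss)

    # Step 3: stop once the loss has not improved for `patience` epochs.
    if valid_loss < best_loss:
        best_loss, bad_epochs = valid_loss, 0
        best_state = copy.deepcopy(model.state_dict())
    else:
        bad_epochs += 1
        if bad_epochs >= patience:
            break

# Step 4: restore the parameters from the best epoch.
model.load_state_dict(best_state)
```

Keeping a deep copy of `state_dict()` in memory is one way to remember the best parameters; writing a checkpoint file to disk, as the training function later in this post does, is another.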

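As the figure shows, after a certain epoch the model's validation error gradually rises: the model is overfitting, so training should be stopped early. Early stopping is essentially a trade-off between training time and generalization error, and different stopping criteria give different results.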
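To make the effect of the stopping criterion concrete, here is a small pure-Python sketch. The helper `stop_epoch` and the list of validation losses are made up for illustration; the only thing that varies is how many non-improving epochs are tolerated.

```python
def stop_epoch(valid_losses, patience):
    """Return (epoch index where training stops, best loss seen so far).

    Training stops once the loss has failed to improve for `patience`
    consecutive epochs. If that never happens, the index of the last
    epoch is returned.
    """
    best, bad_epochs = float("inf"), 0
    for epoch, loss in enumerate(valid_losses):
        if loss < best:
            best, bad_epochs = loss, 0
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch, best
    return len(valid_losses) - 1, best


# Made-up validation losses with a small bump at epoch 3.
losses = [1.00, 0.80, 0.70, 0.72, 0.65, 0.66, 0.67, 0.68, 0.70]

impatient = stop_epoch(losses, patience=1)
patient = stop_epoch(losses, patience=3)
print(impatient)  # (3, 0.7)
print(patient)    # (7, 0.65)
```

The impatient criterion stops at the first bump and saves four epochs of training, while the patient one trains longer and ends up with a lower validation loss. That is the trade-off between training time and generalization error in miniature.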

The following applies early stopping in PyTorch:

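# Train the model using early stopping
import numpy as np
import torch
# EarlyStopping is defined in pytorchtools.py in the repository linked below;
# adjust this import if your copy of the module is named differently.
from pytorchtools import EarlyStopping

# Note: train_loader, valid_loader, optimizer and criterion are expected to be
# defined at module level before train_model is called; batch_size is unused here.
def train_model(model, batch_size, patience, n_epochs):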
    
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = [] 
    
    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=patience, verbose=True)
    
    for epoch in range(1, n_epochs + 1):

        ###################
        # train the model #
        ###################
        model.train() # prep model for training
        for batch, (data, target) in enumerate(train_loader, 1):
            # clear the gradients of all optimized variables
            optimizer.zero_grad()
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output, target)
            # backward pass: compute gradient of the loss with respect to model parameters
            loss.backward()
            # perform a single optimization step (parameter update)
            optimizer.step()
            # record training loss
            train_losses.append(loss.item())

        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
        with torch.no_grad(): # gradients are not needed during validation
            for data, target in valid_loader:
                # forward pass: compute predicted outputs by passing inputs to the model
                output = model(data)
                # calculate the loss
                loss = criterion(output, target)
                # record validation loss
                valid_losses.append(loss.item())

        # print training/validation statistics 
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        
        epoch_len = len(str(n_epochs))
        
        print_msg = (f'[{epoch:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                     f'train_loss: {train_loss:.5f} ' +
                     f'valid_loss: {valid_loss:.5f}')
        
        print(print_msg)
        
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        
        # early_stopping needs the validation loss to check if it has decreased,
        # and if it has, it will make a checkpoint of the current model
        early_stopping(valid_loss, model)
        
        if early_stopping.early_stop:
            print("Early stopping")
            break
        
    # load the last checkpoint with the best model
    model.load_state_dict(torch.load('checkpoint.pt'))

    return model, avg_train_losses, avg_valid_losses
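The function above relies on an `EarlyStopping` object that this post does not define. The sketch below is one possible minimal version that matches how it is called here: `early_stopping(valid_loss, model)`, the `early_stop` flag, and a checkpoint file. It is not the implementation from the linked repository. The `delta` and `path` parameters, the default values, and the log messages are assumptions.

```python
import os
import tempfile

import torch
from torch import nn


class EarlyStopping:
    """Stop training when the validation loss stops improving (illustrative sketch)."""

    def __init__(self, patience=7, verbose=False, delta=0.0, path="checkpoint.pt"):
        self.patience = patience  # epochs to wait without improvement
        self.verbose = verbose    # print a message on every call
        self.delta = delta        # minimum decrease that counts as improvement
        self.path = path          # where the best model is saved
        self.best_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss, model):
        if val_loss < self.best_loss - self.delta:
            # Improvement: checkpoint the model and reset the counter.
            if self.verbose:
                print(f"Validation loss improved ({self.best_loss:.6f} -> {val_loss:.6f}); saving model")
            torch.save(model.state_dict(), self.path)
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.verbose:
                print(f"EarlyStopping counter: {self.counter} of {self.patience}")
            if self.counter >= self.patience:
                self.early_stop = True


# Tiny demonstration with made-up validation losses and a throwaway model.
model = nn.Linear(2, 1)
path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
stopper = EarlyStopping(patience=2, path=path)

for loss in [0.9, 0.7, 0.75, 0.8, 0.6]:
    stopper(loss, model)
    if stopper.early_stop:
        break

# Restore the weights saved at the best epoch (loss 0.7).
model.load_state_dict(torch.load(path))
```

With `patience=2` the loop stops at the fourth value and never sees the final 0.6, which is the price of a short patience. The demonstration writes to a temporary directory, whereas `train_model` above reads `checkpoint.pt` from the working directory.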

The complete code is available at: https://github.com/Bjarten/early-stopping-pytorch/blob/master/MNIST_Early_Stopping_example.ipynb

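Early stopping in PyTorch is a technique for halting model training ahead of schedule. When the model starts to overfit or its performance stops improving, early stopping ends training, which avoids overfitting and saves time and compute.

In PyTorch, this can be implemented with a custom EarlyStopping class. Here is a simple example:

```python
import numpy as np


class EarlyStopping:
    def __init__(self, patience=5, delta=0):
        self.patience = patience  # epochs to wait without improvement
        self.delta = delta        # minimum decrease that counts as improvement
        self.best_loss = np.inf
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop
```

During training, the EarlyStopping class monitors the validation loss and stops training when the stopping condition is met. For example:

```python
# num_epochs, model, validation_data and calculate_validation_loss are
# placeholders that the caller is expected to define.

# Create an EarlyStopping instance
early_stopping = EarlyStopping(patience=3)

for epoch in range(num_epochs):
    # Train the model for one epoch
    # Compute the loss on the validation set
    val_loss = calculate_validation_loss(model, validation_data)
    # Check whether the stopping condition is met
    if early_stopping(val_loss):
        print("Early stopping")
        break
    # Otherwise continue training
```

In this example, `patience` is the number of consecutive epochs the validation loss is allowed to go without improving, and `delta` is the minimum amount by which the loss must drop to count as an improvement. If no such improvement occurs for `patience` epochs in a row, training stops.

That is the basic usage of early stopping in PyTorch: it makes training more efficient and helps avoid overfitting.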
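To see how `delta` changes the outcome, the sketch below repeats the class from the example above in compact form and runs it on a made-up, slowly improving loss sequence. The helper `first_stop` and the loss values are hypothetical and exist only for this illustration.

```python
class EarlyStopping:
    """Compact copy of the class above, without the NumPy dependency."""

    def __init__(self, patience=5, delta=0.0):
        self.patience = patience
        self.delta = delta
        self.best_loss = float("inf")
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True
        return self.early_stop


def first_stop(losses, patience, delta):
    """Return the epoch index at which stopping triggers, or None if it never does."""
    stopper = EarlyStopping(patience=patience, delta=delta)
    for epoch, loss in enumerate(losses):
        if stopper(loss):
            return epoch
    return None


# The loss drops by only 0.01 per epoch.
losses = [1.00, 0.99, 0.98, 0.97, 0.96]

no_delta = first_stop(losses, patience=3, delta=0.0)
with_delta = first_stop(losses, patience=3, delta=0.05)
print(no_delta)    # None
print(with_delta)  # 3
```

With `delta=0.0` every tiny decrease resets the counter, so training never stops on this sequence. With `delta=0.05` the 0.01 steps no longer count as improvements, and training stops after three such epochs.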
