楼主这两天在研究torch,思考它能不能像tf中一样有Early Stopping机制,查阅了一些资料,主要参考了这篇博客,总结一下:
实现方法
安装pytorchtools,而后直引入Early Stopping。
代码:
# 引入 EarlyStopping
from pytorchtools import EarlyStopping
import torch.utils.data as Data # 用于创建 DataLoader
import torch.nn as nn
结合伪代码进行讲解:
model = yourModel() # 伪代码
# 指定损失函数,可以是其他损失函数,根据训练要求决定
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数,注意该损失函数对自动对批量样本的损失取平均
# 指定优化器,可以是其他
optimizer = torch.optim.Adam(model.parameters())
# 初始化 early_stopping 对象
patience = 20 # 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合
early_stopping = EarlyStopping(patience, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容
batch_size = 64 # 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客
n_epochs = 100 # 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练
#----------------------------------------------------------------
# 训练模型,直到 epoch == n_