写在前面:写该博客是为了记录自己这段时间里的学习收获,重在思路,若有疑问,欢迎评论,若发现该博客有错误,欢迎指出。这两天终于有了空闲时间,所以就多写点。该博客使用了这个 g i t h u b github github 仓库中提供的 p y t o r c h _ t o o l s pytorch\_tools pytorch_tools,而且该仓库中还有这个工具的使用案例,建议读者前往查看。在此感谢作者的代码分享。
注意:在评论中我发现有读者使用 pip install pytorchtools
获取相关代码,这是不对的(这是我在文中没有描述清楚造成,不好意思),读者们可以直接去上述的仓库里把其中的 pytorchtools.py
下载(或复制)下来并放置在项目中,这样就可以正常执行 from pytorchtools import EarlyStopping
并使用
E
a
r
l
y
S
t
o
p
p
i
n
g
EarlyStopping
EarlyStopping 了。
实现
有了
p
y
t
o
r
c
h
_
t
o
o
l
s
pytorch\_tools
pytorch_tools 工具后,使用
e
a
r
l
y
s
t
o
p
p
i
n
g
early\ stopping
early stopping 就很简单了。
先从该工具类中导入 E a r l y S t o p p i n g EarlyStopping EarlyStopping .
# import EarlyStopping
from pytorchtools import EarlyStopping
import torch.utils.data as Data # 用于创建 DataLoader
import torch.nn as nn
为了方便描述,这里还是会使用一些伪代码,如果你想阅读详细案例的话,不用犹豫直接看上述工具自带的案例代码。
model = yourModel() # 伪
# 指定损失函数,可以是其他损失函数,根据训练要求决定
criterion = nn.CrossEntropyLoss() # 交叉熵,注意该损失函数对自动对批量样本的损失取平均
# 指定优化器,可以是其他
optimizer = torch.optim.Adam(model.parameters())
# 初始化 early_stopping 对象
patience = 20 # 当验证集损失在连续20次训练周期中都没有得到降低时,停止模型训练,以防止模型过拟合
early_stopping = EarlyStopping(patience, verbose=True) # 关于 EarlyStopping 的代码可先看博客后面的内容
batch_size = 64 # 或其他,该参数属于超参,对于如何选择超参,你可以参考下我的上一篇博客
n_epochs = 100 # 可以设置大一些,毕竟你是希望通过 early stopping 来结束模型训练
#----------------------------------------------------------------
# 训练模型,直到 epoch == n_epochs 或者触发 early_stopping 结束训练
for epoch in range(1, n_epochs + 1):
# 建立训练数据的 DataLoader
training_dataset = Data.TensorDataset(X_train, y_train)
# 把dataset放到DataLoader中
data_loader = Data.DataLoader(
dataset=training_dataset,
batch_size=batch_size, # 批量大小
shuffle=True # 是否打乱数据顺序
)
#---------------------------------------------------
model.train() # 设置模型为训练模式
# 按小批量训练
for batch, (data, target) in enumerate(data_loader):
optimizer.zero_grad() # 清楚所有参数的梯度
output = model(data) # 输出模型预测值
loss = criterion(output, target) # 计算损失
loss.backward() # 计算损失对于各个参数的梯度
optimizer.step() # 执行单步优化操作:更新参数
#----------------------------------------------------
model.eval() # 设置模型为评估/测试模式
# 一般如果验证集不是很大的话,模型验证就不需要按批量进行了,但要注意输入参数的维度不能错
valid_output = model(X_val)
valid_loss = criterion(valid_output, y_val) # 注意这里的输入参数维度要符合要求,我这里为了简单,并未考虑这一点
early_stopping(valid_loss, model)
# 若满足 early stopping 要求
if early_stopping.early_stop:
print("Early stopping")
# 结束模型训练
break
# 获得 early stopping 时的模型参数
model.load_state_dict(torch.load('checkpoint.pt'))
以下是 p y t o r c h _ t o o l s pytorch\_tools pytorch_tools 工具的代码:
import numpy as np
import torch
class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
"""
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.early_stop = False
self.val_loss_min = np.Inf
self.delta = delta
def __call__(self, val_loss, model):
score = -val_loss
if self.best_score is None:
self.best_score = score
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
self.counter += 1
print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0
def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
if self.verbose:
print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), 'checkpoint.pt') # 这里会存储迄今最优模型的参数
self.val_loss_min = val_loss
结束
总结完了,不过这些代码还是需要读者根据自己的模型做出改动。希望这篇博客对你会有所帮助,再一次,欢迎指出错误。