为什么要保存epoch?如果模型比较大,在训练时可能会由于某些意外原因导致训练还没有完全完成就终止,对付这种情况,可以通过每隔一定数量的epoch就保存一次模型参数,下次如果出现训练终止的情况时,就可以加载最新的模型来恢复训练,而不用从头开始。
在首次训练时,先对模型进行保存,这里根据自己代码的实际情况进行划分,笔者以5个epoch保存为一个checkpoint,在保存时,一定要记得保存optimizer,否则,最终的loss结果会出现与预期不相符:
for epoch in range(start_epoch, epoches):
train(train_dataloader, model, loss_func, optimizer, epoch)
test(test_dataloader, model, loss_func)
if (epoch + 1) % 5 == 0:
print('epoch:', epoch + 1)
print('learning rate:', optimizer.state_dict()['param_groups'][0]['lr'])
checkpoint = {
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
'epoch': epoch
}
torch.save(checkpoint, './checkpoint/ckpt_best_%s.pth' % (str(epoch + 1)))
print('Saved all parameters!\n')
保存的模型参数:
从上次的断点处继续训练,笔者在这里列出两种方法,第一种是通过判断RESUME来确定是否恢复训练;第二种是直接加载上次的断点路径来继续训练。
第一种:
RESUME = True #控制是否是恢复训练。False:初次训练;True:继续训练
epoches = 20 #训练了20次
start_epoch = -1
#模型断点的设置
if RESUME:
print('-----------------------------')
path_checkpoint = "./checkpoint/ckpt_best_5.pth" # 断点路径
checkpoint = torch.load(path_checkpoint) # 加载断点
model.load_state_dict(checkpoint['model']) # 加载模型可学习参数
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] + 1 # 设置开始的epoch
print('加载 epoch {} 成功!'.format(start_epoch))
print('-----------------------------')
else:
start_epoch = 0
print('无保存模型,将从头开始训练!')
第二种:
if os.path.exists("./checkpoint/ckpt_best_10.pth"):
checkpoint = torch.load("./checkpoint/ckpt_best_10.pth")
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
start_epoch = checkpoint['epoch'] + 1
print('加载 epoch {} 成功!'.format(start_epoch))
else:
start_epoch = 0
print('无保存模型,将从头开始训练!')