深度学习训练非常耗费时间,如果中间中止训练,则会非常麻烦。所以为了节省时间,通常会让模型从断点中继续训练。为此,需要模型在训练的过程中保存一些关键信息,比如模型的参数、优化器的配置,epoch等。
模型保存
def save_checkpoint(model, epoch,loss,optimizer):
model_out_path = "model/" + "model_epoch_{}.pth".format(epoch)
state = {"epoch": epoch,
"model": model,
'loss':loss,
'optimizer': optimizer.state_dict()}
# check path status
if not os.path.exists("model/"):
os.makedirs("model/")
# save model
torch.save(state, model_out_path)
print("Checkpoint saved to {}".format(model_out_path))
断点恢复
if opt.resume:
if os.path.isfile(opt.resume):
print("===> loading checkpoint: {}".format(opt.resume))
checkpoint = torch.load(opt.r