问题
在进行训练网络模型时,经常会遇到服务器中断或其他原因导致正在训练的模型中断,如果没有保存模型,就要重新训练,费时费力。这种情况怎么解决呢?可以继续原先的模型训练程度继续训练吗?包括weights,epochs,lr,loss等等…
解决
1.通过torch.save()方法保存模型,包括model,loss,epoch,IoU。可以设置每隔几个epochs保存一次。
state = {
"net":model.module.state_dict(),
"loss":val_loss,
"epoch":epoch,
"iou":lb,
}
if not os.path.isdir("checkpoint"):
os.mkdir("checkpoint")
torch.save(state,'./checkpoint/ckpt_best_%s.pth' % (str(fold)))
2.通过args.resume()方法判断是否可以加载保存的模型继续训练
if args.resume == True:
checkpoint = torch.load('./checkpoint/ckpt_best_%s.pth' % (str(fold)))
model.load_state_dict(checkpoint['net'])
best_loss = checkpoint['loss']
epoch = checkpoint['epoch'] + 1
best_iou = checkpoint['iou']