前言
因为有些人是使用服务器去跑程序,不是用的自己的主机,所以不可避免的出现程序被其他使用者中断或者多人共用导致内存爆掉而程序被kill,这个时候在程序中添加断点就显得尤为重要。
提示:以下是本篇文章正文内容,下面案例可供参考
一、断点是什么
其实没有那么难理解,其实原理就是在程序运行时,每训练一个epoch就把模型训练后得到的各种参数保存下来,下次训练直接加载你保存各种参数的这个模型就可以继续接着训练了。
二、步骤
1.保存模型
找到你网络中保存权值的语句,我们就在这附近添加语句即可:
#保存权值语句
if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, total_loss / epoch_step, val_loss / epoch_step_val)))
#保存中断时模型的语句,我设置的是每训练完一轮就保存一下
if (epoch + 1) % 1 == 0 or epoch + 1 == UnFreeze_Epoch:
state = {'model':model.state_dict(),'optimizer':optimizer.state_dict(),'eppoch':epoch}
torch.save(state,log_dir)
2.在训练前加载保存的模型
代码如下(示例):
test_flag = True
log_dir = '.../*.pth'
if test_flag:
checkpoint = torch.load(log_dir)
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
Init_Epoch = checkpoint['eppoch']
print('加载中断处模型成功,继续训练!')
else:
Init_Epoch = 0
print('无保存模型,从头开始训练! ')