在对网络进行训练的时候,尤其是一些结构复杂的大型网络模型,一般需要训练很长时间才能得到比较满意的结果,要是从零开始训练那就是更漫长的等待了!在训练的过程中总可能发生一些预想不到的意外从而造成训练的终止,所以学会如何在网络训练中断后,根据保存的权重文件继续训练就显得至关重要!!!!!
下面就来具体讲讲如何断点续训
一、保存模型
可以看我的这篇博客,看完后应该会明白的,如果有不明白的欢迎评论区留言,互相交流!
二、加载模型
使用 torch.load函数
进行加载
model.load_state_dict(torch.load('权重文件所在的路径', map_location='要存放的设备'))
三、保存权重文件
一般地,权重文件中需要保存:网络模型权重、优化器权重、以及epoch
,便于继续训练
我一般会设置条件保存权重文件,权重文件相对占用的空间比较大,可以设置 每间隔几个epoch 保存一次
if epoch % 5 == 0:
checkpoint = {
"model":model.state_dict(),
"optimizer":optimizer.state_dict(),
"epoch":epoch
} # 要保存的权重文件内容
# 创建保存权重文件所在的文件夹
if not os.path.isdir('./checkpoint'):
os.makedirs("./checkpoint")
# 使用 torch.save() 函数保存权重文件到指定的路径下
torch.save(checkpoint, './checkpoint/ckpt_%d.pth' % int(epoch))
# 或者写成:torch.save(checkpoint, './checkpoint/ckpt_{}.pth '.format(int(epoch)))
# torch.save(checkpoint, f'./checkpoint/ckpt_{int(epoch)}.pth')
这样就可以在checkpoint文件夹下看到每间隔5个epoch保存的权重文件
四、在断点处继续训练
注意 start_epoch
的设定,以保证再次训练时,epoch的次数匹配
start_epoch = 0 # strat_epoch的值一般默认为0,表示从头开始进行训练
# Resume = False # 用于指示是否继续训练 取值为逻辑值: False:初次训练; True:继续训练
Resume = True
# 判断是否继续训练
if Resume:
path_checkpoint = "./checkpoint/ckpt_5.pth" # 断点权重文件的所在路径
checkpoint = torch.load(path_checkpoint) # 加载保存的断点文件
model.load_state_dict(checkpoint['model']) # 加载模型可学习参数
optimizer.load_state_dict(checkpoint['optimizer']) # 加载优化器参数
start_epoch = checkpoint['epoch'] # 设置继续上次训练的epoch
最后,感谢博主
大梦冲冲冲https://blog.csdn.net/qq_37844044
的博文:
pytroch网络训练中断后,根据断点再次训练!!!_pytorch训练中断后怎么恢复-CSDN博客
给予我灵感