训练中难免遇到停电、宕机等情况。如果没有设置断点续训,那么就得从头开始,非常浪费时间和能源。
要实现断点续训,首先要保存之前训练的模型,除此之外,还有优化器的状态。这是因为目前很多优化器的参数是会随着训练而不断变化的,具体可见这篇博客:pytorch学习系列(4):常用优化算法
首先设置两个参数:
start_epoch:续训时开始训练的epoch数
resume:是否实施断点续训
start_epoch = 0
resume = True
模型和优化器设置完成后,加入如下代码,载入checkpoint再进行训练,
if opt.resume:
if os.path.isfile('checkpoint'):
checkpoint = torch.load('checkpoint')
start_epoch = checkpoint['epoch'] + 1
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
print("=> loaded checkpoint (epoch {})".format(checkpoint['epoch']))
else:
print("=> no checkpoint found")
每一轮训练保存模型后,加上如下代码保存epoch、model和optimizer的信息到checkpoint文件中。
checkpoint = {
'epoch': epoch,
'model': model.state_dict(),
'optimizer': optimizer.state_dict(),
}
torch.save(checkpoint,'checkpoint')
这种方式也支持tensorboardX可视化的断点续图,只不过会多产生一个日志文件,只记录了断点后的情况。原来的日志文件中还会继续记录断点后的数据。