保存权重的几种方式
方式一
直接保存权重的state_dict
# 设置保存路径
save_path = './Lenet.pth'
# 进行权重保存
torch.save(model.state_dict(), save_path)
# 加载权重
model.load_state_dict(torch.load(save_path))
方式二
直接保存整个权重
# 保存权重
torch.save(model, save_path)
# 加载权重
model = torch.load(save_path)
方式三
可用于断点重新训练的包括当前训练的周期数、优化器等参数的权重保存
# 保存权重
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
...
}, save_path)
# 加载
checkpoint = torch.load(save_path)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 如果加载权重仅用于测试,则加上下述代码
model.eval()
# 训练时
model.train()
方式三常用于下述情景:
模型在训练过程中因为某些原因中断,想要在上次训练的基础之上接着训练,则需要提前在训练
过程中保存相应的状态,具体如下:
首先要在训练过程中保存模型参数、当前的epoch数以及相应的优化器参数,因为在训练过程中相应的优化器相关参数会不断进行更新,因此不能只保存模型参数,当前训练期数epoch以及相应的优化器。
# 下述代码加在训练的周期迭代过程最后即可
# Save models checkpoints
state_G_A2B = {
'epoch': epoch, # 当前训练期数
'net': netG_A2B.state_dict(), # 网络参数
'optimizer': optimizer_G.state_dict(), # 优化器相关参数
}
torch.save(state_G_A2B, 'output/netG_A2B_%d.pth'%{epoch})
(1)设置参数
parser.add_argument('--resume', type=bool, default=True, help='whether to resume training')
(2)判断resume参数是否为True并进行相应的初始化
注意,在对优化器、网络参数以及epoch进行更新赋值时,再赋值前必须先有其相关定义。
if not opt.resume:
netG_A2B.apply(weights_init_normal)
netG_B2A.apply(weights_init_normal)
netD_A.apply(weights_init_normal)
netD_B.apply(weights_init_normal)
else:
# Load state dicts
# G_A2B保存的权重(包含epoch、权重以及优化器)
checkpoint_G_A2B=torch.load('')
# 加载上次训练结束时所在的epoch
opt.epoch=checkpoint_G_A2B['epoch']
netG_A2B.load_state_dict(checkpoint_G_A2B['net'])
optimizer_G.load_state_dict(checkpoint_G_A2B['optimizer'],strict=False)
# G_B2A保存的权重(包含epoch、权重以及优化器)
checkpoint_G_B2A = torch.load('')
# D_A保存的权重(包含epoch、权重以及优化器)
checkpoint_D_A = torch.load('')
netD_A.load_state_dict(checkpoint_D_A['net'])
optimizer_D_A.load_state_dict(checkpoint_D_A['optimizer'],strict=False)
# D_B保存的权重(包含epoch、权重以及优化器)
checkpoint_D_B = torch.load('')
netD_B.load_state_dict(checkpoint_D_B['net'])
optimizer_D_B.load_state_dict(checkpoint_D_B['optimizer'],strict=False)