模型设置断点重新训练以及训练过程中权重的几种保存和使用方式

保存权重的几种方式

方式一

直接保存权重的state_dict

# 设置保存路径
save_path = './Lenet.pth'
# 进行权重保存 
torch.save(model.state_dict(), save_path)
# 加载权重
model.load_state_dict(torch.load(save_path))

方式二

直接保存整个权重

# 保存权重
torch.save(model, save_path)
# 加载权重
model = torch.load(save_path)

 方式三

可用于断点重新训练的包括当前训练的周期数、优化器等参数的权重保存

# 保存权重
torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            ...
            }, save_path)
# 加载
checkpoint = torch.load(save_path)
start_epoch=checkpoint['epoch']
model.load_state_dict(checkpoint['model_state_dict'])
# 如果加载权重仅用于测试,则加上下述代码
model.eval()
# 训练时
model.train()

方式三常用于下述情景:

模型在训练过程中因为某些原因中断,想要在上次训练的基础之上接着训练,则需要提前在训练

过程中保存相应的状态,具体如下:

首先要在训练过程中保存模型参数、当前的epoch数以及相应的优化器参数,因为在训练过程中相应的优化器相关参数会不断进行更新,因此不能只保存模型参数,当前训练期数epoch以及相应的优化器。

# 下述代码加在训练的周期迭代过程最后即可
# Save models checkpoints
    state_G_A2B = {
        'epoch': epoch,  # 当前训练期数
        'net': netG_A2B.state_dict(),  # 网络参数
        'optimizer': optimizer_G.state_dict(),  # 优化器相关参数
    } 
    torch.save(state_G_A2B, 'output/netG_A2B_%d.pth'%{epoch})

(1)设置参数

parser.add_argument('--resume', type=bool, default=True, help='whether to resume training')

(2)判断resume参数是否为True并进行相应的初始化

注意,在对优化器、网络参数以及epoch进行更新赋值时,再赋值前必须先有其相关定义。

if not opt.resume:
    netG_A2B.apply(weights_init_normal)
    netG_B2A.apply(weights_init_normal)
    netD_A.apply(weights_init_normal)
    netD_B.apply(weights_init_normal)

else:
    # Load state dicts
    # G_A2B保存的权重(包含epoch、权重以及优化器)
    checkpoint_G_A2B=torch.load('')
    # 加载上次训练结束时所在的epoch
    opt.epoch=checkpoint_G_A2B['epoch']
    netG_A2B.load_state_dict(checkpoint_G_A2B['net'])
    optimizer_G.load_state_dict(checkpoint_G_A2B['optimizer'],strict=False)
    # G_B2A保存的权重(包含epoch、权重以及优化器)
    checkpoint_G_B2A = torch.load('')
    # D_A保存的权重(包含epoch、权重以及优化器)
    checkpoint_D_A = torch.load('')
    netD_A.load_state_dict(checkpoint_D_A['net'])
    optimizer_D_A.load_state_dict(checkpoint_D_A['optimizer'],strict=False)
    # D_B保存的权重(包含epoch、权重以及优化器)
    checkpoint_D_B = torch.load('')
    netD_B.load_state_dict(checkpoint_D_B['net'])
    optimizer_D_B.load_state_dict(checkpoint_D_B['optimizer'],strict=False)

  • 2
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值