关于存储和加载模型权重

        checkpoint = load_checkpoint(args.resume)
        model_dict = model.state_dict()
        checkpoint_load = {k: v for k, v in (checkpoint['state_dict']).items() 
                            if k in model_dict}
        model_dict.update(checkpoint_load)
        model.load_state_dict(model_dict)
        start_epoch = checkpoint['epoch']
        best_top1 = checkpoint['best_top1']
        print("=> Start epoch {}  best top1 {:.1%}".format(start_epoch, best_top1))

1.首先先读取arg.sume(已存储的权重)到checkpoint,相当于字典

2.再读取模型中的参数权重到model_dict

3.将checkpoint中key值对应model_dict的数据加载到checkpoint_load中

4.将已经训练好的模型参数更新并加载到已有模型参数中(单卡)

5.再读取checkpoint中的其他参数,以此类推

model.module.load_state_dict(checkpoint['state_dict'])

 加载模型参数(多卡)

torch.save(model.state_dict(), model_out_path)

存储模型参数(单卡)

torch.save(state, fpath)
save_checkpoint({
                'state_dict': model.module.state_dict(),
                'epoch': epoch + 1,
                'best_top1': best_top1
                })

存储模型参数(多卡)以及其他信息,

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值