pytorch加载预训练权重

resume_path 为 checkpoint.pth 的文件路径

checkpoint = torch.load(resume_path, map_location=torch.device(‘cpu’))

修改训练开始轮数

args.start_epoch = checkpoint['epoch']

获取预训练权重的参数

new_param = checkpoint['state_dict']

加载模型参数

model.load_state_dict(new_param)

加载优化器

optimizer.load_state_dict(checkpoint['optimizer'])
代码来源于DCP模型(cvpr22)的源代码
################### args.resume为checkpoint.pth文件 ###################
if args.resume:
    resume_path = osp.join(args.snapshot_path, args.resume)
    if os.path.isfile(resume_path):
        if main_process():
            logger.info("=> loading checkpoint '{}'".format(resume_path))
        ################### 加载预训练权重 ###################
        checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
        ################### 修改训练开始轮数 ###################
        args.start_epoch = checkpoint['epoch']
        ################### 获取预训练权重的参数 ###################
        new_param = checkpoint['state_dict']
        try: 
        	##################### 加载模型参数 ###################
            model.load_state_dict(new_param)
        except RuntimeError:                   # 1GPU loads mGPU model
            for key in list(new_param.keys()):
                new_param[key[7:]] = new_param.pop(key)
            model.load_state_dict(new_param)
        ##################### 加载优化器###################
        optimizer.load_state_dict(checkpoint['optimizer'])
        if main_process():
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(resume_path, checkpoint['epoch']))
    else:
        if main_process():       
            logger.info("=> no checkpoint found at '{}'".format(resume_path))
  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值