深度学习-服务器pytorch多GPU训练踩坑,报错RuntimeError: Error(s) in loading state_dict

文章讨论了在PyTorch中使用多GPU训练模型时,由于nn.DataParallel导致的模型state_dict键名变化问题,以及两种解决方法:一是使用nn.DataParallel加载模型,二是移除state_dict中的module.前缀。这两种方法都能确保在多GPU训练后的模型权重能在测试时正确导入。
摘要由CSDN通过智能技术生成

服务器上使多块显卡训练pytorch模型,语句如下:

if torch.cuda.device_count() > 1:
    print("Let's use ", torch.cuda.device_count(), "GPUs.")     
    #net = torch.nn.DataParallel(net, device_ids=[0, 1, 2, 3])#指定GPU训练
    net = torch.nn.DataParallel(net)#使用所有可用GPU训练

保存语句:

torch.save({'epoch': epoch,
                    'step': step,
                    'net_state_dict': net.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()},
                    cfg.CHECKPOINT_PATH)

然后在测试时,运行测试程序,导入模型加载.pth文件时会报如下类似错误:

Traceback (most recent call last):
  File "D:/code/pytorch_pose/src/test.py", line 14, in <module>
    net = which_model(is_shallow=cfg.IS_SHALLOW, net_state_dict=checkpoint['net_state_dict'])
  File "D:\code\pytorch_pose\src\model\model.py", line 319, in which_model
    return resnet18(net_state_dict)
  File "D:\code\pytorch_pose\src\model\model.py", line 309, in resnet18
    model.load_state_dict(net_state_dict)
  File "D:\anaconda_python3.9\envs\python3.6env\lib\site-packages\torch\nn\modules\module.py", line 1052, in load_state_dict
    self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for ResNet:
    Missing key(s) in state_dict: "conv1.weight", "bn1.weight","layer4.0.conv2.weight", "fc.5.weight", ......."fc.5.bias". 
    Unexpected key(s) in state_dict: "module.conv1.weight", "module.layer2.1.conv1.weight", ......"module.layer3.0. 

加载权重.pth的语句:

    if net_state_dict is not None:
        print('Loading resnet18 checkpoint weight...')
        model.load_state_dict(net_state_dict)

报错原因分析如下:

在使用多GPU训时,我们使用了nn.DataParallel(model)对模型进行了包装,保存模型参数时,会在每个模型参数前添加module字段,作为模型预训练参数字典文件中的key。

因此为了保证在多GPU上训练的结果可以在测试时被正确导入,解决方法有两种:

方法一(已测试可用):

    if net_state_dict is not None:
        print('Loading resnet18 checkpoint weight...')
        model = nn.DataParallel(model) # 添加了该句后,就能正常导入在多GPU上训练的模型参数了
        model.load_state_dict(net_state_dict)

方法二:

if net_state_dict is not None:
        print('Loading resnet18 checkpoint weight...')
        model.load_state_dict({k.replace('module.', ''): v for k, v in torch.load(model_state_path)['state_dict'].items()})
        model.load_state_dict(net_state_dict)

参考链接:https://blog.csdn.net/yangzhengzheng95/article/details/88574200

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值