DataParallel并行运算的模型导入失败Error的解决方法——解除DataParallel封装

在之前的一篇博客中提到了Pytorch框架下,模型使用GPU多卡并行运算来解决显存不足的问题,即使用: Pytorch框架下使用多台GPU的数据并行运算的指令

import torch 
from torch.nn.parallel import DataParallel


if torch.cuda.device_count() > 1:
    model = DataParallel(model)  #数据并行运算

在Pytorch框架中,保存模型的方法为:

torch.save(model.state_dict(), './model.pth')

 假如我们在一个新的任务中需要预先导入model的所有参数,方法为:

new_model.load_state_dict(torch.load('./model.pth'))

但运行后意料之外地出现了如下的error:

你也许会很纳闷,明new_model与model已经设置为相同结构了,为什么会报出结构不匹配的error。 让我们仔细地观察一下:

通过这里的信息我们可以基本分析出来报错的原因来自模型的字典的键(key)出现了不匹配。 通过查询资料发现,DataParallel操作会对模型进行封装且改变键名,具体过程如下:

  1. *使用torch.nn.DataParallel来并行训练模型时,模型实际上会被封装在DataParallel容器中,这个容器可以使模型在多个GPU上进行并行计算*;
  2. 当你保存一个经过DataParallel训练的模型时,PyTorch会将模型的参数保存在一个有关DataParallel的容器中(这就是为什么你在加载模型时需要特别注意。如果你加载了一个带有DataParallel容器的模型参数,并且你的当前代码并没有设置DataParallel,那么PyTorch会尝试加载DataParallel容器,但是当前环境下可能没有多个GPU可用,这就导致了错误)。

解决方法

方法一:导入参数前先使用DataParallel封装

if torch.cuda.device_count() > 1:
    new_model = DataParallel(new_model)  

new_model.load_state_dict(torch.load('./model.pth'))

但是这样会使得新模型还是被DataParallel封装的状态,可能会带来后续的一些麻烦。因此可考虑能否解除DataParallel封装。

方法二:解除DataParallel封装

我们先来分析以下DataParallel封装导致模型的键名更改的具体操作:当模型经过torch.nn.DataParallel操作后,模型的所有键会多出一个前缀module.。这是因为DataParallel在内部对模型进行了包装,以便在多个GPU上进行并行计算。为了使模型能够在多GPU环境下工作,DataParallel在模型的键(parameters和buffers)前添加了module.前缀,见下图。

所以要解除DataParallel对模型的封装,我们只需要删除DataParallel容器即可:

checkpoint = torch.load('./model.pth')
# 删除DataParallel前缀
new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
new_model.load_state_dict(new_state_dict)

这样新模型就不再被DataParallel封装了!

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值