在之前的一篇博客中提到了Pytorch框架下,模型使用GPU多卡并行运算来解决显存不足的问题,即使用:
Pytorch框架下使用多台GPU的数据并行运算的指令
import torch
from torch.nn.parallel import DataParallel
if torch.cuda.device_count() > 1:
model = DataParallel(model) #数据并行运算
在Pytorch框架中,保存模型的方法为:
torch.save(model.state_dict(), './model.pth')
假如我们在一个新的任务中需要预先导入model的所有参数,方法为:
new_model.load_state_dict(torch.load('./model.pth'))
但运行后意料之外地出现了如下的error:
你也许会很纳闷,明new_model与model已经设置为相同结构了,为什么会报出结构不匹配的error。
让我们仔细地观察一下:
通过这里的信息我们可以基本分析出来报错的原因来自模型的字典的键(key)出现了不匹配。
通过查询资料发现,DataParallel操作会对模型进行封装且改变键名,具体过程如下:
- 使用torch.nn.DataParallel来并行训练模型时,模型实际上会被封装在DataParallel容器中,这个容器可以使模型在多个GPU上进行并行计算;
- 当你保存一个经过DataParallel训练的模型时,PyTorch会将模型的参数保存在一个有关DataParallel的容器中(这就是为什么你在加载模型时需要特别注意。如果你加载了一个带有DataParallel容器的模型参数,并且你的当前代码并没有设置DataParallel,那么PyTorch会尝试加载DataParallel容器,但是当前环境下可能没有多个GPU可用,这就导致了错误)。
解决方法
方法一:导入参数前先使用DataParallel封装
if torch.cuda.device_count() > 1:
new_model = DataParallel(new_model)
new_model.load_state_dict(torch.load('./model.pth'))
但是这样会使得新模型还是被DataParallel封装的状态,可能会带来后续的一些麻烦。因此可考虑能否解除DataParallel封装。
方法二:解除DataParallel封装
我们先来分析以下DataParallel封装导致模型的键名更改的具体操作:当模型经过torch.nn.DataParallel操作后,模型的所有键会多出一个前缀module.。这是因为DataParallel在内部对模型进行了包装,以便在多个GPU上进行并行计算。为了使模型能够在多GPU环境下工作,DataParallel在模型的键(parameters和buffers)前添加了module.前缀,见下图。
所以要解除DataParallel对模型的封装,我们只需要删除DataParallel容器即可:
checkpoint = torch.load('./model.pth')
# 删除DataParallel前缀
new_state_dict = {k.replace('module.', ''): v for k, v in checkpoint.items()}
new_model.load_state_dict(new_state_dict)
这样新模型就不再被DataParallel封装了!