解决RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cuda:1问题
在加载多GPU模型后使用pytorch的DataParallel()时出现以上报错,网上查找了很多资料都没有正确的解决方法。
个人猜想原因在于多GPU模型的参数保存在多卡上,而DataParrallel的forward源码中规定(如下所示),model和inputs都必须在GPU: 0上,因此该出现报错
def forward(self, *inputs, **kwargs):
if not self.device_ids:
return self.module(*inputs, **kwargs)
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError("module must have its parameters and buffers "
"on device {} (device_ids[0]) but found one of "
"them on device: {}".format(self.src_device_obj, t.device))
inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
if len(self.device_ids) == 1:
return self.module(*inputs[0], **kwargs[0])
replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
outputs = self.parallel_apply(replicas, inputs, kwargs)
return self.gather(outputs, self.output_device)
解决方法:
保存时使用参数保存,而非模型保存。保存参数也是pytorch更推荐的一种方法。
# 保存参数
torch.save(model_ft.module.state_dict(), 'params.pkl')
# 保存整个模型
torch.save(model_ft, 'net.pkl')