使用多GPU运行Pytorch代码时,出现TypeError: 'DataParallel' object is not iterable
妈呀排查的我都要吐血了终于发现问题出在哪儿,还是我太白目。。
写的是CNN代码,写损失函数模型时,首先要把损失函数放到nn.ModuleList中,然后把它作为DataParallel的module参数传进去。
self.loss_module = nn.ModuleList()
self.loss_module.append(nn.L1Loss())
self.loss_module = nn.DataParallel(
self.loss_module, range(args.n_GPUs)
)
之后在多GPU的情况下如果要以iterable的方式获取self.loss_module中的损失函数,就得拿其中的module属性!之前我就直接return self.loss_module了。。
for l in self.get_loss_module():
...
def get_loss_module(self):
if self.n_GPUs == 1:
return self.loss_module
else:
return self.loss_module.module