错误原因
这是使用nn.DataParallel产生的错误,DataParallel或DistributedDataParallel产生的错误。
从它的源码中就能看出来:
class DataParallel(Module):
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__()
if not torch.cuda.is_available():
self.module = module
self.device_ids = []
return
if device_ids is None:
device_ids = list(range(torch.cuda.device_count()))
if output_device is None:
output_device = device_ids[0]
解决方法
只需要将之前的net.XXX改为:
net.module.XXX即可。
例如,将netG.state_dict()
改为:
netG.module.state_dict()