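If the checkpoint was saved from a model wrapped in nn.DataParallel (its state_dict keys start with "module."), wrap the model before loading the parameters: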
model = nn.DataParallel(model)
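Wrapping works because a state_dict saved from an nn.DataParallel model stores every key with a "module." prefix (e.g. "module.fc.weight"), and only a model that is also wrapped has matching keys. A commonly used alternative is to strip that prefix so the checkpoint loads into a plain, unwrapped model. The sketch below shows that alternative. It is not the author's method, and the helper name strip_module_prefix is hypothetical. It works on any mapping, including the OrderedDict returned by torch.load.

```python
from collections import OrderedDict
from typing import Mapping, TypeVar

V = TypeVar("V")

PREFIX = "module."


def strip_module_prefix(state_dict: Mapping[str, V]) -> "OrderedDict[str, V]":
    """Hypothetical helper: drop the leading 'module.' that nn.DataParallel
    adds to every parameter/buffer name, preserving key order."""
    cleaned: "OrderedDict[str, V]" = OrderedDict()
    for key, value in state_dict.items():
        # Only strip the prefix at the start of a key; other keys pass through unchanged.
        new_key = key[len(PREFIX):] if key.startswith(PREFIX) else key
        cleaned[new_key] = value
    return cleaned


if __name__ == "__main__":
    saved = OrderedDict([("module.fc.weight", 1), ("module.fc.bias", 2)])
    print(list(strip_module_prefix(saved)))  # ['fc.weight', 'fc.bias']
```

With PyTorch this would be used as model.load_state_dict(strip_module_prefix(torch.load(path))), where path is a placeholder for your checkpoint file. Choose this approach when you want the loaded model to run without DataParallel.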
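You may also see output such as IncompatibleKeys(missing_keys=[], unexpected_keys=[]). When both lists are empty, this is only the return value of load_state_dict and means every key matched, so it is not actually an error.
A real error you may hit is: RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0])
In that case, also add: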
model = nn.DataParallel(model, device_ids=[0, 1])
device = torch.device("cuda:0")
model.to(device)
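The RuntimeError appears because DataParallel requires the wrapped module's parameters and buffers to sit on device_ids[0] (cuda:0 here). The sketch below shows the whole sequence end to end. It rests on several assumptions: a toy nn.Linear stands in for the real model, a freshly built wrapped model stands in for a saved checkpoint, and the helper name build_parallel_model is hypothetical. It falls back to one GPU or the CPU when two GPUs are not available, so it can run anywhere.

```python
from typing import Tuple

import torch
import torch.nn as nn


def build_parallel_model(model: nn.Module) -> Tuple[nn.Module, torch.device]:
    """Hypothetical helper: move the model to the primary device, then wrap it."""
    if torch.cuda.device_count() >= 2:
        # DataParallel needs parameters and buffers on device_ids[0], i.e. cuda:0.
        device = torch.device("cuda:0")
        model = nn.DataParallel(model.to(device), device_ids=[0, 1])
    else:
        # Assumption: with fewer than two GPUs, use one device but still wrap,
        # so state_dict keys keep the "module." prefix.
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        model = nn.DataParallel(model.to(device))
    return model, device


if __name__ == "__main__":
    torch.manual_seed(0)
    # Stand-in for a checkpoint saved from a DataParallel model.
    source, _ = build_parallel_model(nn.Linear(4, 2))
    checkpoint = source.state_dict()

    model, device = build_parallel_model(nn.Linear(4, 2))
    result = model.load_state_dict(checkpoint)
    # Empty lists mean every key matched; this is a return value, not an error.
    print(result.missing_keys, result.unexpected_keys)  # [] []
    print(model(torch.randn(3, 4, device=device)).shape)  # torch.Size([3, 2])
```

The model is moved to the device before it is wrapped. Calling model.to(device) after wrapping, as in the lines above, also works, because DataParallel forwards .to() to the inner module.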