pytorch报错AssertionError: Invalid device id
代码:
model = torch.nn.DataParallel(model, device_ids=[5,7]).cuda()
因为我想使用两个GPU,但是出现报错,原因是因为pytorch默认使用cuda=0的GPU,当gpu编号为device:0的设备被占用时,指定其他编号gpu使用torch.nn.DataParallel(model, device_ids=[5, 7])指定gpu编号会出现AssertionError: Invalid device id错误
因此加一行代码修改为
torch.cuda.set_device(5)
model = torch.nn.DataParallel(model, device_ids=[5,7]).cuda()