单gpu
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
指定一个以上的指定gpu
device_ids=[0, 1]
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model,device_ids=device_ids)
if torch.cuda.is_available():
model.cuda()
全部gpu
if torch.cuda.device_count() > 1:
model = nn.DataParallel(model)
if torch.cuda.is_available():
model.cuda()
- 注意使用多gpu时训练或测试 inputs和labels需加载到gpu中
inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())