如果有多个GPUs,可以使用nn.DataParallel封装模型,然后使用model.to(device)把模型放在GPUs上。
model = Model(input_size, output_size)
if torch.cuda.device_count() > 1:
# dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
model = nn.DataParallel(model)
model.to(device)
如果有3个GPUs,batch size为30,则每个gpu batch 10个输入。
挖个坑:
看一下torch.nn.DataParallel和torch.nn.parallel.DistributedDataParallel的区别