- 一机多卡 数据并行
# 假设就一个数据
data = torch.rand([16, 10, 5])
# 前向计算要求数据都放进GPU0里面
# device = torch.device('cuda:0')
# data = data.to(device)
data = data.cuda()
# 将网络同步到多个GPU中
model_p = torch.nn.DataParalle(model.cuda(), device_ids=[0, 1, 2, 3])
logits = model_p(inputs)
# 接下来计算loss
loss = crit(logits, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 注:需要保证数据和初始网路都在你所选择的多个gpu中的第一块上就行。