PyTorch支持使用多张显卡进行训练。有两种常见的方法可以实现这一点:
- 使用
torch.nn.DataParallel
封装模型,然后使用多张卡进行并行计算。例如:
import torch
import torch.nn as nn
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# 定义模型
model = MyModel()
# 将模型放在多张卡上
if torch.cuda.device_count() > 1:
print("使用{}张卡".format(torch.cuda.device_count()))
model = nn.DataParallel(model)
model.to(device)
# 训练模型
for data in dataloader:
# 放到设备上
inputs, labels = data[0].to(device), data[1].to(device)
# 前向计算
outputs = model(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()