DataParallel的基本使用方法很简单,只需设置device_ids即可,如下所示:
device_ids = [0, 1, 2, 3]
model = torch.nn.DataParallel(model, device_ids=device_ids)
device_ids为你要使用的GPU号。如果你未使用DataParallel之前用的便是单GPU进行训练,那么对于数据不需要额外的操作,否则,你需要将模型的输入数据转移到cuda上,如:
# 此处device与device_ids无关,你可以设置device = torch.device("cuda:0")
input = input.to(device)
如果顺利的话,简单的两步就可以实现加速了。
然而,由墨菲定律可得:凡是可能出错的事就一定会出错。常见问题如下。
问题1:如果model里定义了一个函数,如初始化函数init_hidden等,并已实现DataParallel,在train函数里该如何调用?
class Model(nn.Module):
def __init__(self, ):
pass
def forward(self, ):
pass
def init_hidden(self, ):
pass
model = Model()
model = nn.DataParallel(model, device_ids=devic