首先准备好模型和数据,假设有四张显卡的话,编号:0,1,2,3
首先加载数据,一般加载到第一张显卡上:
if torch.cuda.is_available():
# model.cuda()
features = features.cuda()
然后将模型加载进来,采用如下的方式:
model = nn.DataParallel(model.cuda(),device_ids=[0,1,2,3])
首先准备好模型和数据,假设有四张显卡的话,编号:0,1,2,3
首先加载数据,一般加载到第一张显卡上:
if torch.cuda.is_available():
# model.cuda()
features = features.cuda()
然后将模型加载进来,采用如下的方式:
model = nn.DataParallel(model.cuda(),device_ids=[0,1,2,3])