使用 torch.nn.DataParallel 进行多 GPU 训练：
# Multi-GPU training with torch.nn.DataParallel.
# GPU indices to train on. The first entry is the primary device: the model is
# moved there, and DataParallel replicates it to the other listed devices.
device_ids = [0, 1]
# ... (code omitted in the original snippet)
# NOTE(review): `use_gpu` and `resnetv2sn18` are defined in code not shown
# here; `nn` is presumably `torch.nn` — confirm against the full file.
model = resnetv2sn18()
if use_gpu and device_ids:
    # nn.DataParallel requires the module's parameters and buffers to already
    # be on device_ids[0], so move the model there first in every GPU case.
    model = model.cuda(device_ids[0])
    if len(device_ids) > 1:  # multi-GPU training
        model = nn.DataParallel(model, device_ids=device_ids)
    # With exactly one id this is plain single-GPU training; no wrapper needed.