Just a personal note; experienced readers can skip.
torch.nn.DataParallel makes it easy to run a model across multiple GPUs. The model-setup code can be modified as follows.
Original:
# Model
model = CleanU_Net(in_channels=1, out_channels=2)
# model = CleanU_Net()
model = torch.nn.DataParallel(
    model, device_ids=list(range(torch.cuda.device_count()))
).cuda()
Modified:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# model.module recovers the original, unwrapped model from the DataParallel wrapper
model = model.module
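To see why unwrapping matters: DataParallel adds a `module.` prefix to every parameter name, so a checkpoint saved from the wrapper will not load into a plain model. A minimal sketch of the pattern above, using a hypothetical `SmallNet` as a stand-in for CleanU_Net:

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for CleanU_Net.
class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SmallNet().to(device)

if torch.cuda.is_available():
    # Wrap for multi-GPU training; parameter names gain a "module." prefix.
    model = torch.nn.DataParallel(
        model, device_ids=list(range(torch.cuda.device_count()))
    ).cuda()
    # Unwrap to get back the original SmallNet (e.g. before saving a checkpoint).
    inner = model.module
else:
    # On CPU there is no wrapper, so the model is already "unwrapped".
    inner = model

# The unwrapped state_dict has clean parameter names ("fc.weight", not
# "module.fc.weight"), so it loads on any device without key remapping.
state = inner.state_dict()
```

Saving `inner.state_dict()` rather than the wrapper's state_dict is the usual way to keep checkpoints portable between single-GPU, multi-GPU, and CPU runs.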