废话不多说,直接上代码
修改cycleGan中的代码如下
原代码
disc_H = Discriminator(in_channels=3).to(config.DEVICE)
disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
修改后的代码
disc_H = Discriminator(in_channels=3)
disc_Z = Discriminator(in_channels=3)
gen_Z = Generator(img_channels=3, num_residuals=9)
gen_H = Generator(img_channels=3, num_residuals=9)
#使用多gpu加速
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
disc_H = nn.DataParallel(disc_H, device_ids=[0,1])
disc_Z = nn.DataParallel(disc_Z, device_ids=[0,1])
gen_Z = nn.DataParallel(gen_Z, device_ids=[0,1])
gen_H = nn.DataParallel(gen_H, device_ids=[0,1])
disc_H.to(config.DEVICE)
disc_Z.to(config.DEVICE)
gen_Z.to(config.DEVICE)
gen_H.to(config.DEVICE)
多gpu训练得到的结果单gpu貌似也能调用,效果如何,自行测试