废话不多说,直接上代码
修改cycleGan中的代码如下
原代码
disc_H = Discriminator(in_channels=3).to(config.DEVICE)
disc_Z = Discriminator(in_channels=3).to(config.DEVICE)
gen_Z = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
gen_H = Generator(img_channels=3, num_residuals=9).to(config.DEVICE)
修改后的代码
disc_H = Discriminator(in_channels=3)
disc_Z = Discriminator(in_channels=3)
gen_Z = Generator(img_channels=3, num_residuals=9)
gen_H = Generator(img_channels=3, num_residuals=9)
#使用多gpu加速
if torch.cuda.device_count() > 1:
print("Let's use", torch.cuda.device_count(), "GPUs!")
disc_H