Training process
1. Train Generators
Loss function:
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
loss_pixelwise = (pixelwise_loss(fake_A, real_A) + \
pixelwise_loss(fake_B, real_B)) / 2
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
loss_G = loss_GAN + loss_cycle + loss_pixelwise
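Written out as math, the generator objective above looks roughly like this. The notation is my own and not taken from the paper: x_A and x_B stand for real_A and real_B, and the three \ell terms stand for adversarial_loss, pixelwise_loss and cycle_loss, whose exact form is not shown in these notes.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{GAN}} &= \tfrac{1}{2}\Big[\ell_{\mathrm{adv}}\big(D_B(G_{AB}(x_A)),\,1\big) + \ell_{\mathrm{adv}}\big(D_A(G_{BA}(x_B)),\,1\big)\Big] \\
\mathcal{L}_{\mathrm{pixel}} &= \tfrac{1}{2}\Big[\ell_{\mathrm{pix}}\big(G_{BA}(x_B),\,x_A\big) + \ell_{\mathrm{pix}}\big(G_{AB}(x_A),\,x_B\big)\Big] \\
\mathcal{L}_{\mathrm{cycle}} &= \tfrac{1}{2}\Big[\ell_{\mathrm{cyc}}\big(G_{BA}(G_{AB}(x_A)),\,x_A\big) + \ell_{\mathrm{cyc}}\big(G_{AB}(G_{BA}(x_B)),\,x_B\big)\Big] \\
\mathcal{L}_{G} &= \mathcal{L}_{\mathrm{GAN}} + \mathcal{L}_{\mathrm{cycle}} + \mathcal{L}_{\mathrm{pixel}}
\end{aligned}
```

Note that the pixelwise term compares G_BA(x_B) directly against x_A, which presumably only makes sense when each batch holds corresponding (paired) images from the two domains.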
Code:
# ------------------
# Train Generators
# ------------------
optimizer_G.zero_grad()
# GAN loss
fake_B = G_AB(real_A)
loss_GAN_AB = adversarial_loss(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = adversarial_loss(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Pixelwise translation loss
loss_pixelwise = (pixelwise_loss(fake_A, real_A) + \
pixelwise_loss(fake_B, real_B)) / 2
# Cycle loss
loss_cycle_A = cycle_loss(G_BA(fake_B), real_A)
loss_cycle_B = cycle_loss(G_AB(fake_A), real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + loss_cycle + loss_pixelwise
loss_G.backward()
optimizer_G.step()
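The snippet above relies on names that are defined elsewhere in the training script. Below is a minimal sketch of one way to set them up so that a single generator step runs end to end. Everything here is an assumption for illustration: the tiny networks are hypothetical stand-ins for the real generator and discriminator, the choice of MSELoss and L1Loss for the three criteria is a guess, and the random tensors replace a real data loader.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins: one-layer conv nets instead of the real models.
def tiny_generator():
    return nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())

def tiny_discriminator():
    # Outputs a patch map of scores rather than a single scalar per image.
    return nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)

G_AB, G_BA = tiny_generator(), tiny_generator()
D_A, D_B = tiny_discriminator(), tiny_discriminator()

# Assumed loss criteria (the notes do not say which ones are used).
adversarial_loss = nn.MSELoss()
pixelwise_loss = nn.L1Loss()
cycle_loss = nn.L1Loss()

# A single optimizer updates both generators together.
optimizer_G = torch.optim.Adam(
    list(G_AB.parameters()) + list(G_BA.parameters()), lr=2e-4, betas=(0.5, 0.999)
)

# Random tensors in [-1, 1] standing in for one batch from each domain.
real_A = torch.rand(2, 3, 16, 16) * 2 - 1
real_B = torch.rand(2, 3, 16, 16) * 2 - 1

# Targets shaped like the discriminator output: 1 for "real", 0 for "fake".
with torch.no_grad():
    out_shape = D_A(real_A).shape
valid = torch.ones(out_shape)
fake = torch.zeros(out_shape)

# One generator step, following the same order as the snippet above.
optimizer_G.zero_grad()
fake_B = G_AB(real_A)
fake_A = G_BA(real_B)
loss_GAN = (adversarial_loss(D_B(fake_B), valid)
            + adversarial_loss(D_A(fake_A), valid)) / 2
loss_pixelwise = (pixelwise_loss(fake_A, real_A)
                  + pixelwise_loss(fake_B, real_B)) / 2
loss_cycle = (cycle_loss(G_BA(fake_B), real_A)
              + cycle_loss(G_AB(fake_A), real_B)) / 2
loss_G = loss_GAN + loss_cycle + loss_pixelwise
loss_G.backward()
optimizer_G.step()

print(f"loss_G = {loss_G.item():.4f}")
```

The valid and fake targets are built from the discriminator's output shape, so the same code keeps working if the discriminator is swapped for one with a different patch size.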
2. Train Discriminator A
Loss function:
loss_D_A = (loss_real + loss_fake) / 2
Code:
# -----------------------
# Train Discriminator A
# -----------------------
optimizer_D_A.zero_grad()
# Real loss
loss_real = adversarial_loss(D_A(real_A), valid)
# Fake loss (fake_A comes from the generator step; detach() keeps gradients out of G_BA)
loss_fake = adversarial_loss(D_A(fake_A.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
3. Train Discriminator B
Loss function:
loss_D_B = (loss_real + loss_fake) / 2
Code:
# -----------------------
# Train Discriminator B
# -----------------------
optimizer_D_B.zero_grad()
# Real loss
loss_real = adversarial_loss(D_B(real_B), valid)
# Fake loss (fake_B comes from the generator step; detach() keeps gradients out of G_AB)
loss_fake = adversarial_loss(D_B(fake_B.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
loss_D_B.backward()
optimizer_D_B.step()
# Combined discriminator loss (no backward() is called on it here)
loss_D = 0.5 * (loss_D_A + loss_D_B)
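Both discriminator steps call .detach() on the generated images. Here is a small sketch of what that does, using hypothetical one-layer stand-ins for G_BA and D_A and an assumed MSELoss criterion: with the fake batch detached, the backward pass fills in gradients for the discriminator only and leaves the generator alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical one-layer stand-ins for a generator and a discriminator.
G_BA = nn.Linear(4, 4)
D_A = nn.Linear(4, 1)
adversarial_loss = nn.MSELoss()  # assumed criterion

real_B = torch.randn(8, 4)
fake_target = torch.zeros(8, 1)  # plays the role of `fake` in the notes

fake_A = G_BA(real_B)

# detach() cuts fake_A off from the graph that produced it,
# so backward() stops at the discriminator input.
loss_fake = adversarial_loss(D_A(fake_A.detach()), fake_target)
loss_fake.backward()

generator_untouched = all(p.grad is None for p in G_BA.parameters())
discriminator_updated = all(p.grad is not None for p in D_A.parameters())
print(generator_untouched, discriminator_updated)
```

In the real loop there is a second reason for the detach: loss_G.backward() has already freed the generator graph, so backpropagating through fake_A again without detaching it would raise an error unless the graph had been retained.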
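P.S. This really is very similar to CycleGAN. Maybe I'll read the paper tomorrow if I have time?
Here is a figure from the paper. I also trained the model myself, but the results were not great (the colors came out strange), so I am not posting them.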
Rows from top to bottom: (1) Real image from domain A (2) Translated image from domain A (3) Reconstructed image from domain A (4) Real image from domain B (5) Translated image from domain B (6) Reconstructed image from domain B
Source code: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/discogan