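Code explained in this post: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/cyclegan
Official source code: https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
The official source code follows the same approach as the code explained here; this post mainly walks through the overall training flow.
For research purposes, however, I recommend studying the official source code, which is also fairly simple.
Training process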
1. Train Generators
Loss functions:
loss_identity = (loss_id_A + loss_id_B) / 2
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + \
lambda_cyc * loss_cycle + \
lambda_id * loss_identity
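Written out as formulas, the generator objective above looks as follows. This assumes the repo's choices of criterion_GAN = MSELoss (least-squares GAN) and criterion_cycle = criterion_identity = L1Loss; a and b denote real images from domains A and B, and "mean" is the average over all elements, as both losses compute by default:

```latex
\begin{aligned}
\mathcal{L}_{id}  &= \tfrac{1}{2}\Big(\operatorname{mean}\big\lvert G_{BA}(a) - a\big\rvert + \operatorname{mean}\big\lvert G_{AB}(b) - b\big\rvert\Big)\\
\mathcal{L}_{GAN} &= \tfrac{1}{2}\Big(\operatorname{mean}\big[(D_B(G_{AB}(a)) - 1)^2\big] + \operatorname{mean}\big[(D_A(G_{BA}(b)) - 1)^2\big]\Big)\\
\mathcal{L}_{cyc} &= \tfrac{1}{2}\Big(\operatorname{mean}\big\lvert G_{BA}(G_{AB}(a)) - a\big\rvert + \operatorname{mean}\big\lvert G_{AB}(G_{BA}(b)) - b\big\rvert\Big)\\
\mathcal{L}_{G}   &= \mathcal{L}_{GAN} + \lambda_{cyc}\,\mathcal{L}_{cyc} + \lambda_{id}\,\mathcal{L}_{id}
\end{aligned}
```

The target value 1 in the GAN term corresponds to the all-ones `valid` tensor: the generators are rewarded when the discriminators score their outputs as real.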
Code:
# ------------------
# Train Generators
# ------------------
optimizer_G.zero_grad()
# Identity loss
loss_id_A = criterion_identity(G_BA(real_A), real_A)
loss_id_B = criterion_identity(G_AB(real_B), real_B)
loss_identity = (loss_id_A + loss_id_B) / 2
# GAN loss (valid: all-ones target tensor with the discriminator's output shape)
fake_B = G_AB(real_A)
loss_GAN_AB = criterion_GAN(D_B(fake_B), valid)
fake_A = G_BA(real_B)
loss_GAN_BA = criterion_GAN(D_A(fake_A), valid)
loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
# Cycle loss
recov_A = G_BA(fake_B)
loss_cycle_A = criterion_cycle(recov_A, real_A)
recov_B = G_AB(fake_A)
loss_cycle_B = criterion_cycle(recov_B, real_B)
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2
# Total loss
loss_G = loss_GAN + \
lambda_cyc * loss_cycle + \
lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()
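A minimal, self-contained sketch of one generator update follows, so the step above can be run in isolation. Everything here except the loss structure is an assumption: tiny hypothetical convolutional networks stand in for the repo's ResNet generators and PatchGAN-style discriminators, random tensors stand in for images, and lambda_cyc = 10, lambda_id = 5 and the Adam settings are illustrative values.

```python
import itertools

import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for the real generator/discriminator models.
def make_generator():
    return nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())

def make_discriminator():
    # Outputs a 1-channel map of patch scores, PatchGAN style.
    return nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)

G_AB, G_BA = make_generator(), make_generator()
D_A, D_B = make_discriminator(), make_discriminator()

criterion_GAN = nn.MSELoss()       # least-squares GAN loss
criterion_cycle = nn.L1Loss()
criterion_identity = nn.L1Loss()
lambda_cyc, lambda_id = 10.0, 5.0  # assumed weights

# One optimizer over both generators' parameters.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()), lr=2e-4, betas=(0.5, 0.999)
)

real_A = torch.randn(2, 3, 16, 16)
real_B = torch.randn(2, 3, 16, 16)

# Targets shaped like the discriminator output: 1 = real, 0 = fake.
with torch.no_grad():
    out_shape = D_B(real_B).shape
valid, fake = torch.ones(out_shape), torch.zeros(out_shape)

optimizer_G.zero_grad()
loss_identity = (criterion_identity(G_BA(real_A), real_A)
                 + criterion_identity(G_AB(real_B), real_B)) / 2
fake_B = G_AB(real_A)
fake_A = G_BA(real_B)
loss_GAN = (criterion_GAN(D_B(fake_B), valid) + criterion_GAN(D_A(fake_A), valid)) / 2
loss_cycle = (criterion_cycle(G_BA(fake_B), real_A)
              + criterion_cycle(G_AB(fake_A), real_B)) / 2
loss_G = loss_GAN + lambda_cyc * loss_cycle + lambda_id * loss_identity
loss_G.backward()
optimizer_G.step()  # updates only G_AB and G_BA
```

Note that `loss_G.backward()` also writes gradients into D_A and D_B, because the GAN term passes through them. optimizer_G does not hold those parameters, so the discriminators are not updated here, but their stale gradients are the reason each discriminator step begins with its own `zero_grad()`.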
2. Train Discriminator A
Loss function:
loss_D_A = (loss_real + loss_fake) / 2
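Assuming criterion_GAN is MSELoss as above, loss_D_A is the least-squares discriminator loss. Here ã denotes a fake image returned by fake_A_buffer, and the targets 1 and 0 correspond to the `valid` and `fake` tensors:

```latex
\mathcal{L}_{D_A} = \tfrac{1}{2}\Big(\operatorname{mean}\big[(D_A(a) - 1)^2\big] + \operatorname{mean}\big[D_A(\tilde{a})^2\big]\Big)
```

D_A is pushed to output 1 on real A images and 0 on generated ones; loss_D_B has the same form with b and samples from fake_B_buffer.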
Code:
# -----------------------
# Train Discriminator A
# -----------------------
optimizer_D_A.zero_grad()
# Real loss
loss_real = criterion_GAN(D_A(real_A), valid)
# Fake loss (on a batch mixing current and previously generated samples; detach() keeps gradients out of G_BA)
fake_A_ = fake_A_buffer.push_and_pop(fake_A)
loss_fake = criterion_GAN(D_A(fake_A_.detach()), fake)
# Total loss
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
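`fake_A_buffer.push_and_pop(fake_A)` feeds the discriminator a history of generated images rather than only the latest batch. The following is a sketch modeled on the repo's `ReplayBuffer` (pool size 50, with a 50% chance of swapping in an older sample once the pool is full); treat it as an illustration of the idea rather than a verbatim copy:

```python
import random

import torch


class ReplayBuffer:
    """Image pool that returns a mix of new and previously generated samples."""

    def __init__(self, max_size=50):
        assert max_size > 0, "buffer size must be positive"
        self.max_size = max_size
        self.data = []

    def push_and_pop(self, batch):
        to_return = []
        for element in batch.detach():  # iterate over the batch dimension
            element = element.unsqueeze(0)
            if len(self.data) < self.max_size:
                # Pool not full yet: store the new sample and return it as-is.
                self.data.append(element)
                to_return.append(element)
            elif random.uniform(0, 1) > 0.5:
                # Return a random old sample and replace it with the new one.
                i = random.randint(0, self.max_size - 1)
                to_return.append(self.data[i].clone())
                self.data[i] = element
            else:
                # Return the new sample without storing it.
                to_return.append(element)
        return torch.cat(to_return)


random.seed(0)
buf = ReplayBuffer(max_size=4)
out1 = buf.push_and_pop(torch.zeros(4, 3, 8, 8))  # fills the pool with zeros
out2 = buf.push_and_pop(torch.ones(4, 3, 8, 8))   # mix of old zeros and new ones
```

Each sample in `out2` is either an older all-zeros image pulled from the pool or one of the new all-ones images. Training the discriminator on this mix is meant to reduce oscillation, since it cannot simply adapt to whatever the generator produced in the most recent step.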
3. Train Discriminator B
Loss function:
loss_D_B = (loss_real + loss_fake) / 2
Code:
# -----------------------
# Train Discriminator B
# -----------------------
optimizer_D_B.zero_grad()
# Real loss
loss_real = criterion_GAN(D_B(real_B), valid)
# Fake loss (on a batch mixing current and previously generated samples; detach() keeps gradients out of G_AB)
fake_B_ = fake_B_buffer.push_and_pop(fake_B)
loss_fake = criterion_GAN(D_B(fake_B_.detach()), fake)
# Total loss
loss_D_B = (loss_real + loss_fake) / 2
loss_D_B.backward()
optimizer_D_B.step()
# Average of both discriminator losses (reported in the training log, not backpropagated)
loss_D = (loss_D_A + loss_D_B) / 2
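To see what `.detach()` does in the discriminator steps, here is a minimal sketch of one Discriminator A update. It uses hypothetical toy networks and assumed Adam settings, and omits the replay buffer for brevity. After `backward()`, only D_A has gradients, while G_BA's remain `None`, so the discriminator step never pushes updates into the generator:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy stand-ins for G_BA and D_A.
G_BA = nn.Sequential(nn.Conv2d(3, 3, kernel_size=3, padding=1), nn.Tanh())
D_A = nn.Conv2d(3, 1, kernel_size=4, stride=2, padding=1)
criterion_GAN = nn.MSELoss()
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=2e-4, betas=(0.5, 0.999))

real_A = torch.randn(2, 3, 16, 16)
real_B = torch.randn(2, 3, 16, 16)
fake_A = G_BA(real_B)  # still attached to G_BA's autograd graph

with torch.no_grad():
    out_shape = D_A(real_A).shape
valid, fake = torch.ones(out_shape), torch.zeros(out_shape)

optimizer_D_A.zero_grad()
loss_real = criterion_GAN(D_A(real_A), valid)
# detach() cuts the graph, so this loss cannot backpropagate into G_BA.
loss_fake = criterion_GAN(D_A(fake_A.detach()), fake)
loss_D_A = (loss_real + loss_fake) / 2
loss_D_A.backward()
optimizer_D_A.step()
```

Without the `detach()`, `loss_D_A.backward()` would also compute gradients for G_BA that pull it toward making its fakes easier to reject, which is the opposite of what the generator should learn.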