PyTorch DiscoGAN code study

Training process

1. Train Generators

 

 

Loss function:

loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2
loss_pixelwise = (pixelwise_loss(fake_A, real_A) + \
                  pixelwise_loss(fake_B, real_B)) / 2
loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

loss_G = loss_GAN + loss_cycle + loss_pixelwise


Code:

        # ------------------
        #  Train Generators
        # ------------------

        optimizer_G.zero_grad()

        # GAN loss
        fake_B = G_AB(real_A)
        loss_GAN_AB = adversarial_loss(D_B(fake_B), valid)
        fake_A = G_BA(real_B)
        loss_GAN_BA = adversarial_loss(D_A(fake_A), valid)

        loss_GAN = (loss_GAN_AB + loss_GAN_BA) / 2

        # Pixelwise translation loss
        loss_pixelwise = (pixelwise_loss(fake_A, real_A) + \
                          pixelwise_loss(fake_B, real_B)) / 2

        # Cycle loss
        loss_cycle_A = cycle_loss(G_BA(fake_B), real_A)
        loss_cycle_B = cycle_loss(G_AB(fake_A), real_B)
        loss_cycle = (loss_cycle_A + loss_cycle_B) / 2

        # Total loss
        loss_G = loss_GAN + loss_cycle + loss_pixelwise

        loss_G.backward()
        optimizer_G.step()
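
For reference, the three criteria (adversarial_loss, pixelwise_loss, cycle_loss) and the valid / fake target tensors used above are created once before the training loop. A minimal sketch of that setup, assuming an LSGAN-style MSE adversarial loss, L1 for the pixel-wise and cycle terms, and a PatchGAN-shaped target; the batch size and patch shape are placeholder values, not quoted from the repo:

import torch
import torch.nn as nn

# Criteria for the three generator loss terms (sketch: MSE for the adversarial
# term, L1 for the pixel-wise and cycle-consistency terms).
adversarial_loss = nn.MSELoss()
pixelwise_loss = nn.L1Loss()
cycle_loss = nn.L1Loss()

# Ground-truth targets for the adversarial loss. `patch` stands for the spatial
# output shape of a PatchGAN-style discriminator; the values are assumptions.
batch_size = 8
patch = (1, 16, 16)
valid = torch.ones((batch_size, *patch), requires_grad=False)
fake = torch.zeros((batch_size, *patch), requires_grad=False)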

2. Train Discriminator A

 

Loss function:

loss_D_A = (loss_real + loss_fake) / 2

Code:

        # -----------------------
        #  Train Discriminator A
        # -----------------------

        optimizer_D_A.zero_grad()

        # Real loss
        loss_real = adversarial_loss(D_A(real_A), valid)
        # Fake loss (on batch of previously generated samples)
        loss_fake = adversarial_loss(D_A(fake_A.detach()), fake)
        # Total loss
        loss_D_A = (loss_real + loss_fake) / 2

        loss_D_A.backward()
        optimizer_D_A.step()
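
Worth noting: fake_A.detach() cuts the computation graph, so this update only trains D_A and no gradient flows back into G_BA (the generators were already updated in step 1). Since Train Discriminator B below repeats exactly the same pattern, the update can be wrapped in a small helper; the function below is my own illustration, not code from the repo:

def discriminator_step(discriminator, optimizer, real_imgs, fake_imgs,
                       valid, fake, adversarial_loss):
    """One discriminator update: push real images toward `valid` and
    (already detached) generated images toward `fake`."""
    optimizer.zero_grad()
    loss_real = adversarial_loss(discriminator(real_imgs), valid)
    loss_fake = adversarial_loss(discriminator(fake_imgs), fake)
    loss_D = (loss_real + loss_fake) / 2
    loss_D.backward()
    optimizer.step()
    return loss_D

# Usage mirroring the block above:
# loss_D_A = discriminator_step(D_A, optimizer_D_A, real_A, fake_A.detach(),
#                               valid, fake, adversarial_loss)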

3. Train Discriminator B

 

Loss function:

loss_D_B = (loss_real + loss_fake) / 2

Code:

        # -----------------------
        #  Train Discriminator B
        # -----------------------

        optimizer_D_B.zero_grad()
        # Real loss
        loss_real = adversarial_loss(D_B(real_B), valid)
        # Fake loss (on batch of previously generated samples)
        loss_fake = adversarial_loss(D_B(fake_B.detach()), fake)
        # Total loss
        loss_D_B = (loss_real + loss_fake) / 2

        loss_D_B.backward()
        optimizer_D_B.step()

        loss_D = 0.5 * (loss_D_A + loss_D_B)
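
Putting the three steps together, the surrounding setup and loop look roughly like the sketch below. The GeneratorUNet / Discriminator class names, learning rate, epoch count, and dataloader are assumptions on my part (check models.py and the argument parser in the repo for the real values):

import itertools
import torch

from models import GeneratorUNet, Discriminator  # assumed module / class names

# Two generators (A->B and B->A) and one discriminator per domain.
G_AB, G_BA = GeneratorUNet(), GeneratorUNet()
D_A, D_B = Discriminator(), Discriminator()

# A single optimizer over both generators' parameters, one per discriminator.
# lr / betas are typical GAN values, assumed rather than quoted.
optimizer_G = torch.optim.Adam(
    itertools.chain(G_AB.parameters(), G_BA.parameters()),
    lr=2e-4, betas=(0.5, 0.999))
optimizer_D_A = torch.optim.Adam(D_A.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D_B = torch.optim.Adam(D_B.parameters(), lr=2e-4, betas=(0.5, 0.999))

n_epochs = 200  # assumed hyperparameter
for epoch in range(n_epochs):
    for i, batch in enumerate(dataloader):  # dataloader yields {"A": ..., "B": ...}
        real_A, real_B = batch["A"], batch["B"]
        # 1) generator step, 2) discriminator A step, 3) discriminator B step:
        # exactly the three code blocks quoted above, in that order.
        ...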

P.S.: This really is very similar to CycleGAN; I'll take a look at the paper tomorrow if I have time.

Below is a figure from the paper. I trained the model myself, but the results were not great (the colors came out strange), so I won't post my own outputs.

(Caption of the paper's figure) Rows from top to bottom: (1) real image from domain A; (2) translated image from domain A; (3) reconstructed image from domain A; (4) real image from domain B; (5) translated image from domain B; (6) reconstructed image from domain B.

 

Source code: https://github.com/eriklindernoren/PyTorch-GAN/tree/master/implementations/discogan
