CycleGan论文阅读以及调试笔记

论文
code(pytorch部署,简易可读版本,个人调试的是这个版本)

论文阅读

论文主要提出了一个cycle-gan,实现图像风格迁移的工作。创新点在于图像迁移项目中并没有很多paired data可进行训练,unpaired图像对更多一点。但是我们又想生成图像,就是要学习一个映射:使图像X很像图像Y。由于这个mapping是非常严苛的,因此可以同时学习一个逆映射使Y非常像X。同时,引入了一个循环一致损失函数(cycle consistency loss)进行优化。该技术可以用于多项领域,泛化能力强。

准备的数据是unpaired data。可以看出paired data间存在着联系,而unpaired data中x和y之间并没有其他的信息。
在这里插入图片描述
所以目标在于学习两个映射G(x)和F(y)以及对抗性判别器,论文的图示很清晰的说明了这点。同时,为进一步将这个映射正则化,引入了2个循环一致性损失函数进行对抗性网络。图三建议与第三部分公式一起看,主要是优化目标函数。
在这里插入图片描述

网络架构

借鉴了Johnson的GNN网络架构,根据代码说明一下。loss中的参数设定在原文中的training details,如有需要自行阅读。

#n_residual_blocks: for 128*128 training images, use 6 blocks; for 256*256 training images, use 9 blocks, please feel free to set this hyper-parameter.
class Generator(nn.Module):
    def __init__(self, input_nc, output_nc, n_residual_blocks=9):
        super(Generator, self).__init__()
         
        # Initial convolution block       
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        # Downsampling
        in_features = 64
        out_features = in_features*2
        #paper content: including 3 convs,several residual, use instancenorm
        for _ in range(2):
            model += [  nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features*2

        # Residual blocks
        for _ in range(n_residual_blocks):
            model += [ResidualBlock(in_features)]

        # Upsampling, paper content: two fractionally-strided convs with stride 1/2
        out_features = in_features//2
        for _ in range(2):
            model += [  nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
                        nn.InstanceNorm2d(out_features),
                        nn.ReLU(inplace=True) ]
            in_features = out_features
            out_features = in_features//2

        # Output layer, paper content: one conv maps feature to rgb
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

对抗性判别器部分

class Discriminator(nn.Module):
    def __init__(self, input_nc):
        super(Discriminator, self).__init__()

        # A bunch of convolutions one after another
        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, padding=1),
                    nn.InstanceNorm2d(512), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        # FCN classification layer
        model += [nn.Conv2d(512, 1, 4, padding=1)]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        # Average pooling and flatten
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)

损失函数部分:

# Lossess,initialize
criterion_GAN = torch.nn.MSELoss()
criterion_cycle = torch.nn.L1Loss()
criterion_identity = torch.nn.L1Loss()
        ###### Generators A2B and B2A ######
        optimizer_G.zero_grad()

        # Identity loss
        # G_A2B(B) should equal B if real B is fed, F(G(x))≈x, therefore calculate MAE
        same_B = netG_A2B(real_B)
        loss_identity_B = criterion_identity(same_B, real_B)*5.0
        # G_B2A(A) should equal A if real A is fed
        same_A = netG_B2A(real_A)
        loss_identity_A = criterion_identity(same_A, real_A)*5.0

        # GAN loss, fake images use MSE to calculate loss
        fake_B = netG_A2B(real_A)
        pred_fake = netD_B(fake_B)
        loss_GAN_A2B = criterion_GAN(pred_fake, target_real)

        fake_A = netG_B2A(real_B)
        pred_fake = netD_A(fake_A)
        loss_GAN_B2A = criterion_GAN(pred_fake, target_real)

        # Cycle loss, for example, whether recovered A equals A, therefore also use MAE
        recovered_A = netG_B2A(fake_B)
        loss_cycle_ABA = criterion_cycle(recovered_A, real_A)*10.0

        recovered_B = netG_A2B(fake_A)
        loss_cycle_BAB = criterion_cycle(recovered_B, real_B)*10.0

        # Total loss
        loss_G = loss_identity_A + loss_identity_B + loss_GAN_A2B + loss_GAN_B2A + loss_cycle_ABA + loss_cycle_BAB
        loss_G.backward()

实验分析

在section5对实验结果进行了分析。对于loss func而言,做了消融实验,分别移除了GAN和cycle-consistency loss,然后也做了GAN+前向cycle/GAN+后向cycle。实验结果表明,这样会使训练不稳定,可以看下对比图。
很清晰地看出CycleGAN的稳定性,其他的要么细节不够,要么失真……

调试经验

这个代码也是蛮好调的,用monet2photo数据集先进行了下复现,batchsize设为1就导致训练速度还是比较慢……对数据量要求低,所以还要什么自行车呢。用自己的数据集进行训练,效果还行,但是我的数据集大小不一致,因此在test脚本中还是需要在transform_中加入resize和crop(这里换成了centercrop,还会进行下一步优化),有些图像被crop掉了不完整,后面会进行优化下做对比实验。以及code中用visdom有点鸡肋,如果不用可以注释掉。后续有进一步优化的tricks会继续补充。

  • 4
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值