Cycle-GAN阅读笔记

Summary

Cycle-GAN是图片域与域之间的生成对抗网络,它能够在不同风格的图片中相互转换。相比较于Pix2pix-GAN,其最大的优点是对数据集的要求低,不需要成对的数据集便可以实现不同图片域的转换,这大大降低了生成对抗网络对于数据集的要求。

Contribution

  • 传统的训练都是利用图片对进行训练,存在的问题就是图片对不好收集,比如从真实图片转换到MoNet风格的图片,论文提出了circle结构,实现了图片域<->图片域的转换

Related Work

  • 生成对抗网络
  • 图像到图像翻译
  • 未配对的图像到图像翻译
  • 周期一致性
  • 神经风格转换

Approach and Model Architecture

Cycle-GAN的架构
在这里插入图片描述
其原理是对于两个图片域X, Y,训练G, F使得G(X)=Y且F(Y)=X,即完成了两个图片域的相互转换,训练损失即使得 D X D_X DX区分不出F(y)和X, D Y D_Y DY区分不出G(x)和Y (这里使用了least-squares loss)。另外为了确保单一映射,作者在对抗损失之外还提出了循环一致性损失,即F(G(x))与x的一范式最小,G(F(y))和y也一样。

将这两者的Loss线性累加,最终给出的损失函数为
在这里插入图片描述
其中
在这里插入图片描述在这里插入图片描述
作者采用的网络结构,采用了Johnson等人的生成网络的体系结构,对于鉴别器网络,作者使用了70×70个PatchGAN,这也是pix2pix使用的鉴别器。

训练方面,作者使用了两个小trick,第一个是对于对抗损失,作者使用了least-squares loss;第二个是,为了使得鉴别器不至于能力太超前,作者设置了图像缓冲区,即此刻的鉴别器识别的的是生成器生成的向前第n张图(作者为此设置了50图像的缓冲区),这样能够保证生成器和鉴别器能力相适应。

Comparison against baselines

为了评估Circle-GAN的生成效果,作者遵循了Pix2pix-GAN的评估方法,在不同任务不同指标情况下对生成图片进行评估,以下是评估结果

在这里插入图片描述
当然应该着重比较Pix2pix-GAN与Circle-GAN,因为他们的生成图片具有相同的意义,当然由于Pix2pix-GAN属于图像对的学习,所以理论上Circle-GAN性能更差一点也情有可原。

另一些数值指标

在这里插入图片描述
作者还研究了Circle-GAN中使用的Loss对最终生成图片评分的影响,结果发现对抗损失和循环一致性损失都起着很关键的作用,缺一不可。另外如果只训练一个方向,例如只训练 L ( G , D Y , X , Y ) \mathcal{L}(G,D_Y,X,Y) L(G,DY,X,Y),会常引起训练不稳定性并导致模式崩溃,特别是对于没有训练的那个方向。
在这里插入图片描述
为了进一步直观地说明Loss的缺失对生成图片的影响,作者给出另一个实验结果,图中可知缺失某些Loss对最后的生成缺失产生了一些影响。

在这里插入图片描述

作者还测试了一下图像的重构,即 F ( G ( x ) ) F(G(x)) F(G(x)),并将其与原来的图片x进行比较,结果发现重构的图片大致能够恢复出输入图片的细节,表明域与域之间的单一映射较为准确。

在这里插入图片描述
接着作者进行了Pix2pix-GAN一样的实验,结论是Circle-GAN能产生接近Pix2pix的效果,这其实是挺难得的了,毕竟Circle-GAN是没有配对监督的方法。

在这里插入图片描述
我们注意到培训数据的翻译通常比测试数据的翻译更具吸引力,更多的实验发布在https://junyanz.github.io/CycleGAN/中,里面的实验都很有趣。

在某些情况下,给GAN加上额外的损失十分有必要,比如在关于艺术图像转变的示例中,可以增加一个L1范式作为损失,确保较大程度保留颜色信息。
在这里插入图片描述
更多艺术风格的转换图
在这里插入图片描述
下一个实验是数据增强:增强景深
在这里插入图片描述
艺术图像转换+1
在这里插入图片描述
不同方法在图像转换中的比较,Circle-GAN取得了领先的水平。
在这里插入图片描述
虽然Circle-GAN能在大多数图像转换的任务中表现出良好的效果,但是在一些特定的任务中,它也存在着很多不足。在左边猫狗转换中,Circle-GAN只能进行微小的改变,而在右边马到斑马的转化中,由于数据集的问题,人也染上了斑马的条纹。
在这里插入图片描述
我们还观察到配对训练数据可达到的结果与未配对方法可达到的结果之间仍存在差距。在某些情况下,这种差距可能很难消除,甚至无法消除:例如,我们的方法有时会在photos→labels任务的输出中置换用于树和建筑物的标签。要解决这种歧义,可能需要某种形式的弱语义监督。集成弱或半监督的数据可能会导致翻译器的功能更加强大,但其成本仅为全监督系统的注释成本的一小部分。

最后一些实验图
在这里插入图片描述在这里插入图片描述
在这里插入图片描述

Code

Cycle-GAN的代码主要参考了https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#custom-model-and-dataset
给出最关键训练的部分代码如下,做个了总结

  • 该Cycle-GAN的训练部分
    • 训练鉴别器
      • 对抗损失 ( M S E ? MSE? MSE?)
        • 对于 X X X域到 Y Y Y域,训练鉴别器 D Y D_Y DY,使得 D Y D_Y DY能判别 G X ( x ) G_X(x) GX(x)为假, y y y为真
        • 对于 Y Y Y域到 X X X域,训练鉴别器 D X D_X DX,使得 D X D_X DX能判别 G Y ( y ) G_Y(y) GY(y)为假, x x x为真
    • 训练生成器
      • 对抗损失 ( M S E ? MSE? MSE?)
        • 对于 X X X域到 Y Y Y域,训练生成器 G X G_X GX,使得 D Y D_Y DY能判别 G X ( x ) G_X(x) GX(x)为真
        • 对于 Y Y Y域到 X X X域,训练生成器 G Y G_Y GY,使得 D X D_X DX能判别 G Y ( y ) G_Y(y) GY(y)为真
      • 循环一致性损失 ( L 1 L1 L1)
        • 对于 X X X域到 X X X域,训练生成器 G X G_X GX G Y G_Y GY,使得 G Y ( G X ( x ) ) G_Y(G_X(x)) GY(GX(x))接近于 x x x
        • 对于 Y Y Y域到 Y Y Y域,训练生成器 G Y G_Y GY G X G_X GX,使得 G X ( G Y ( y ) ) G_X(G_Y(y)) GX(GY(y))接近于 y y y
      • 保留原图像信息损失 ( L 1 L1 L1) (可选)
        • 对于 X X X域到 Y Y Y域,训练生成器 G X G_X GX,使得 G X ( x ) G_X(x) GX(x)接近于 x x x
        • 对于 Y Y Y域到 X X X域,训练生成器 G Y G_Y GY,使得 G Y ( y ) G_Y(y) GY(y)接近于 y y y
        if self.isTrain:
            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
                assert(opt.input_nc == opt.output_nc)
            self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
            # define loss functions
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

    def backward_D_basic(self, netD, real, fake):
        """Calculate GAN loss for the discriminator

        Parameters:
            netD (network)      -- the discriminator D
            real (tensor array) -- real images
            fake (tensor array) -- images generated by a generator

        Return the discriminator loss.
        We also call loss_D.backward() to calculate the gradients.
        """
        # Real
        pred_real = netD(real)
        loss_D_real = self.criterionGAN(pred_real, True)
        # Fake
        pred_fake = netD(fake.detach())
        loss_D_fake = self.criterionGAN(pred_fake, False)
        # Combined loss and calculate gradients
        loss_D = (loss_D_real + loss_D_fake) * 0.5
        loss_D.backward()
        return loss_D

    def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

    def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        self.forward()      # compute fake images and reconstruction images.
        # G_A and G_B
        self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        self.backward_G()             # calculate gradients for G_A and G_B
        self.optimizer_G.step()       # update G_A and G_B's weights
        # D_A and D_B
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
        self.backward_D_A()      # calculate gradients for D_A
        self.backward_D_B()      # calculate graidents for D_B
        self.optimizer_D.step()  # update D_A and D_B's weights
  • 0
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值