[Pytorch系列-75]:生成对抗网络GAN - 图像生成开源项目pytorch-CycleGAN-and-pix2pix - CycleGAN网络结构与代码实现详解

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122106440


目录

第1章 网络的定义

1.1 网络结构

1.2 代码来源

1.3 网络结构代码解读

1.4 输入数据集处理代码解读

1.5 前向运算

第2章 网络的训练

1.1 G生成网络的结构与代码解读

1.2 D-A判决网络的结构与代码解读

1.3 D-A判决网络的结构与代码解读

1.4 网络整体的优化算法



第1章 网络的定义

1.1 网络结构

  • 相对于基础型的GAN网络,CycleGAN增加了一个核心的还原网络,导致相关的训练也跟着发生了相应的变化,因此还原网络是核心。
  • 还原是双向的,不仅仅是真实输入图片-》Fake图片-》真实输入图片的还原。还包括真实的输出图片 -> Fake图片 -》真实的输出图片的还原。
  • CycleGAN一共有4个网络:G_A2B, D_A2B, G_B2A, D_B2A, 后两个是新增的 。

1.2 代码来源

pytorch-CycleGAN-and-pix2pix\models\cycle_gan_model.py

1.3 网络结构代码解读

    def __init__(self, opt):
        """Initialize the CycleGAN class.

        Parameters:
            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']

        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        # 真实输入图片 =》 生成图片 =》 真实输入图片的恢复图片,这是成组图片
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        # 真实的输出图片 =》 生成图片 =》 真实输出图片的恢复图片,这是成组图片
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        
        # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
        if self.isTrain and self.opt.lambda_identity > 0.0:  
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        # combine visualizations for A and B
        self.visual_names = visual_names_A + visual_names_B 

        # specify the models you want to save to the disk. 
        if self.isTrain:
            # 训练模式,定义4个网络其中G_B和D_B是新增
            self.model_names = ['G_A', 'G_B', 'D_A', 'D_B']
        else:  # during test time, only load Gs
            # 测试模式,仅仅需要生成网络,其中G_B是新增。
            self.model_names = ['G_A', 'G_B']

        # define networks (both Generators and discriminators)
        # The naming is different from those used in the paper.
        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
        # 定义G_A和G_B网络
        self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
        self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)

        # 训练模式下,定义D_A和D_B网络。
        if self.isTrain:  # define discriminators
            self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
            self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

        if self.isTrain:
             # only works when input and output images have the same number of channels
            if opt.lambda_identity > 0.0: 
                assert(opt.input_nc == opt.output_nc)

            # create image buffer to store previously generated images
            self.fake_A_pool = ImagePool(opt.pool_size)  
            # create image buffer to store previously generated images
            self.fake_B_pool = ImagePool(opt.pool_size)  

            # define loss functions: 定义所用到的loss函数, 
            # 这里有三种loss,对应三种loss
            # define GAN loss.MSE, MSE loss,用于计算D网络的判决结果的loss
            self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  
            # L1 Loss:图片还原程度的loss
            self.criterionCycle = torch.nn.L1Loss()
            # L1 Loss:图片转换后的损失度loss
            self.criterionIdt   = torch.nn.L1Loss()
            
            # 定义所用到的优化器
            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.

            # 定义G网络的优化器:优化的参数包括G_A和G_B网络的参数。
            self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))

            # 定义D网络的优化器:优化的参数包括D_A和D_B网络的参数。
            self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))

            self.optimizers.append(self.optimizer_G)
            self.optimizers.append(self.optimizer_D)
  • train模式下需要定义G_A2B, D_A2B, G_B2A, D_B2A网络, 而在测试或预测模式下,只需要定义G_A2B和G_B2A网络。
  • 只有在训练模式下,才需要定义loss和优化算法。

1.4 输入数据集处理代码解读

    def set_input(self, input):
        """Unpack input data from the dataloader and perform necessary pre-processing steps.

        Parameters:
            input (dict): include the data itself and its metadata information.

        The option 'direction' can be used to swap domain A and domain B.
        """
        AtoB = self.opt.direction == 'AtoB'
        self.real_A = input['A' if AtoB else 'B'].to(self.device)
        self.real_B = input['B' if AtoB else 'A'].to(self.device)
        self.image_paths = input['A_paths' if AtoB else 'B_paths']
  • real_A: 真实的输入图片。
  • real_B: 真实的标签图片 (标签不一定是分类的数值,也可以是一张图片)

1.5 前向运算

    def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
        self.fake_B = self.netG_A(self.real_A)  # G_A(A)
        self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
        self.fake_A = self.netG_B(self.real_B)  # G_B(B)
        self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))
  • 前向运算只使用G网络,创造或生成图片。
  • 这里有4种前向运算,因此没有结果返回,生成结果存放在4个成员变量中。

第2章 网络的训练

1.1 G生成网络的结构与代码解读

(1)G网络的训练架构

 (2)G网络Loss代码实现

    def backward_G(self):
        """Calculate the loss for generators G_A and G_B"""
        lambda_idt = self.opt.lambda_identity
        lambda_A = self.opt.lambda_A
        lambda_B = self.opt.lambda_B
        # Identity loss
        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
            
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
        else:
            self.loss_idt_A = 0
            self.loss_idt_B = 0

        # GAN loss D_A(G_A(A))
        self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)

        # GAN loss D_B(G_B(B))
        self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)

        # Forward cycle loss || G_B(G_A(A)) - A||
        self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A

        # Backward cycle loss || G_A(G_B(B)) - B||
        self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B

        # combined loss and calculate gradients
        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
        self.loss_G.backward()

疑问:上述代码中的loss_idt_A与loss_idt_B为什么是如下的公式?

        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
             # 疑问:为什么real_B与real_B转换后的图片idt_A相比?            


            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt

             # 疑问:为什么real_A与real_A转换后的图片idt_B相比?         

为什么不是这样?

        if lambda_idt > 0:
            # G_A should be identity if real_B is fed: ||G_A(B) - B||
            self.idt_A = self.netG_A(self.real_B)
            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_A) * lambda_B * lambda_idt
            
            # G_B should be identity if real_A is fed: ||G_B(A) - A||
            self.idt_B = self.netG_B(self.real_A)
            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_B) * lambda_A * lambda_idt

答案:内在逻辑

  • G_A网络对Real_B输入尽可能的不要转换,直接生成Real_B
  • G_B网络对Real_A输入尽可能的不要转换,直接还原成Real_A
     

1.2 D-A判决网络的结构与代码解读

(1)D-A网络的训练架构

  (2)D-A网络Loss代码实现

       def backward_D_A(self):
        """Calculate GAN loss for discriminator D_A"""
        fake_B = self.fake_B_pool.query(self.fake_B)
        self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)

1.3 D-A判决网络的结构与代码解读

(1)D-B网络的训练架构

 (2)D-B网络Loss代码实现

    def backward_D_B(self):
        """Calculate GAN loss for discriminator D_B"""
        fake_A = self.fake_A_pool.query(self.fake_A)
        self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)

1.4 网络整体的优化算法

       def optimize_parameters(self):
        """Calculate losses, gradients, and update network weights; called in every training iteration"""
        # forward
        # compute fake images and reconstruction images.
        # 前向运算:包括4个运算,不是1个运算
        self.forward()      
        
        # 一起训练G_A and G_B
        # Ds require no gradients when optimizing Gs
        # 锁定D网络
        self.set_requires_grad([self.netD_A, self.netD_B], False) 
        # 复位G网络的梯度
        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
        # 计算G网络的梯度
        self.backward_G()             # calculate gradients for G_A and G_B
        # 更新G网络的梯度
        self.optimizer_G.step()       # update G_A and G_B's weights

        # 独立训练D_A and D_B
        # 使能D网络训练
        self.set_requires_grad([self.netD_A, self.netD_B], True)
        # set D_A and D_B's gradients to zero
        # 复位D网络梯度
        self.optimizer_D.zero_grad()  
        # 计算D_A网络梯度
        self.backward_D_A()      # calculate gradients for D_A
        # 计算D_B网络梯度
        self.backward_D_B()      # calculate graidents for D_B
        # 更新D_A和D_B网络梯度
        self.optimizer_D.step()  # update D_A and D_B's weights
  • 锁定D网络,训练G_A和D_B网络,使得输出图片能够骗过D_A和D_B网络。
  • 开放D网络,训练D_A和D_B网络,  能够识别出输出图片是fake图片,即生成图片。
  • 重新锁定D网络,训练G_A和D_B网络,使得输出图片能够骗过D_A和D_B网络。
  • 依次类推,不断对抗、优化、迭代、更新,直到D网络无法判决出G网络输出的真假,得到以假乱真的效果。

作者主页(文火冰糖的硅基工坊):文火冰糖(王文兵)的博客_文火冰糖的硅基工坊_CSDN博客

本文网址:https://blog.csdn.net/HiWangWenBing/article/details/122106440

  • 5
    点赞
  • 91
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 4
    评论
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

文火冰糖的硅基工坊

你的鼓励是我前进的动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值