qiuzitao深度学习之PyTorch实战(十四)

史上最简单、实际、通俗易懂的PyTorch实战系列教程!(新手友好、小白请进、建议收藏)

CycleGan网络

你可能听过AI换脸,明星换脸,那你知道它是怎么合成的么?CycleGan网络带你见见世面。

在这里插入图片描述

一、CycleGan网络所需数据

我们CycleGan网络不需要两个一一配对的数据,照样可以进行训练和预测。不需要知道一样形态的斑马和马,也可以把马造出斑马。配对的意思就是如下图的Paired下面的白色鞋子和有颜色的鞋子,他们除了颜色不同,其他的特征是一样的配对的。
在这里插入图片描述
只需要有trainA和trainB就行。
在这里插入图片描述

二、CycleGan整体网络架构

需要Gab和Gba,斑马生成马,再还原回去,还原成斑马,这样才可以有形态一样的特征,才可以进行图像合成。如果只有单项就可能生成不了一样形态的了。

在这里插入图片描述
整体网络架构分成两部分,一部分把普通马作为输入的,另一部分是以斑马作为输入的,两个数据集,前着走一遍,倒着走一遍。一部分是先普通马生成斑马(A2B),然后再斑马还原成普通马(B2A),另一部分是先斑马生成普通马(B2A),然后再普通马还原成斑马(A2B)。

整体网络架构中有四个网络,两个生成网络和两个判别网络。G生成网络中有Gab(就是A2B)和Gba(就是B2A),D判别网络也有两个,一个Ga,一个Gb。有四种损失函数,能达到何种结果,完全是你的损失函数决定的,CycleGan网络最核心的当然就是Cycle损失函数了。
在这里插入图片描述
然后这里的D判别网络和之前的有点特别,它是PatchGAN。它得到的不是一个sigmode数值,而是一个 N * N 的特征图。

在这里插入图片描述

三、基于CycleGan开源项目实战图像合成

GitHub地址:https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix
datasets地址:http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/(这个网站提供很多数据集可以给大家实战)
人家训练好的模型:http://efrosgans.eecs.berkeley.edu/cyclegan/pretrained_models/(我们可以直接拿来用,做为预训练模型)

在这里插入图片描述
maps就是很多地图的数据集,我们要用的是 horse2zebre 就是马和斑马的数据集。

在这里插入图片描述
上面这个是很多预训练模型,可以直接拿来用。

这里我们下载horse2zebra的数据集和预训练模型来跑。

3.1、数据读取与预处理操作

首先我们要设置训练的参数,因为这个开源项目里面是包含了两个项目的,我们要指定一个。

--dataroot ./datasets/horse2zebra 
--name horse2zebra_cyclegan 
--model cycle_gan

在这里插入图片描述
下载的数据集放在datasets文件夹里面,pth预训练模型放在checkpoints文件夹里面。
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述

下载好需要的东西和指定好参数后就可以开始训练了。

dataset_class = find_dataset_using_name(opt.dataset_mode)
self.dataset = dataset_class(opt)

找到指定项目,这里就是看你指定哪些参数,就跑哪个项目。

self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A')  # create a path '/path/to/data/trainA'
self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B')  # create a path '/path/to/data/trainB'

这个就是你的 trainA 和 trainB,导向两个训练集路径的代码。

self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))   # load images from '/path/to/data/trainA'
self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size))    # load images from '/path/to/data/trainB'

这个是你trainA和trainB的数据个数,一共是1096个数据。

input_nc = self.opt.output_nc if btoA else self.opt.input_nc       # get the number of channels of input image
output_nc = self.opt.input_nc if btoA else self.opt.output_nc      # get the number of channels of output image

这个是两个输入输出的通道,一般彩色就三通道。

3.2、生成网络模块构造
BaseModel.__init__(self, opt)
        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
        visual_names_A = ['real_A', 'fake_B', 'rec_A']
        visual_names_B = ['real_B', 'fake_A', 'rec_B']
        if self.isTrain and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
            visual_names_A.append('idt_B')
            visual_names_B.append('idt_A')

        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B

损失函数:两个G网络(G_A,G_B),两个D网络(D_A,D_B),循环一致性,所以有cycle_A,cycle_B,idt_A,idet_B(保持输入输出的相同)

 model = [nn.ReflectionPad2d(3),
                 nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
                 norm_layer(ngf),
                 nn.ReLU(True)]

nn.ReflectionPad2d(3):这个操作其实就是padding操作。我们点进去看可以看到padding结果。
在这里插入图片描述

比如你输入是[[[0.,1.,2.],      经过nn.replicationPad2d操作后就在周围加了两层(默认参数是翻转加了两层)
       [3.,4.,5.],
       [6.,7.,8.]]]])
       
生成网络模块其实就是特征图卷积,尺寸越来越小,但是数量越来越多的堆叠起来,这样子去卷积,然后再进行反卷积复原,也就是特征图越来越少,特征图尺寸也越来越大直到和原来的图片一样大。

3.3、判别网络模块构造
if self.isTrain:  # define discriminators 定义判别器
    self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                    opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
    self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                    opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

点进来看网络里面的结构

在这里插入图片描述
在这里插入图片描述

norm_layer = get_norm_layer(norm_type=norm):先做一个基于颜色通道的归一化。

然后就是先做一个卷积,输入是3颜色通道,得到64个特征图,就是卷积–Normalization(归一化)-- 激活函数。

无论特征图卷积成多少×多少的,但是最后一定要生成一个 NN1 的结果给判别就好。

3.4、损失函数 identity loss计算方法

在这里插入图片描述

self.fake_A_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.pool_size)  # create image buffer to store previously generated images

上图(代码):先说一下,这里是建立一个缓存区域,我们要从A–B,中间会生成一个缓存区域Gab,这个有用的,要和我们后面的Gba进行比对的,也就是损失值计算。

self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device)  # define GAN loss.
            self.criterionCycle = torch.nn.L1Loss()
            self.criterionIdt = torch.nn.L1Loss()

上图(代码):L1Loss这里就是按照绝对值去计算一个差异性,我们希望生成的假B和原来的真B是差异越小越好。

在这里插入图片描述
上图(代码):看到MSEloss这里,这里就是计算我们输入的和还原的结果是不是一样的。

在这里插入图片描述

def forward(self):
        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
    self.fake_B = self.netG_A(self.real_A)  # G_A(A)
    self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
    self.fake_A = self.netG_B(self.real_B)  # G_B(B)
    self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

上图(代码):前向传播,先有一个实际的数据A(real_A)传进去网络Ga,生成一个假的B(fake_B),接下来就是还原了,把假的B(fake_B)传入Gb还原成原来的A数据,这就是从前往后走的;接下来就是从后往前走的,先由真实数据B传入Gb生成一个假的A(fake_A),然后再进行一个还原,还原成真实数据B。

在这里插入图片描述

# G_A and G_B
    self.set_requires_grad([self.netD_A, self.netD_B], False)  # Ds require no gradients when optimizing Gs
    self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
    self.backward_G()             # calculate gradients for G_A and G_B
    self.optimizer_G.step()       # update G_A and G_B's weights
# D_A and D_B
    self.set_requires_grad([self.netD_A, self.netD_B], True)
    self.optimizer_D.zero_grad()   # set D_A and D_B's gradients to zero
    self.backward_D_A()      # calculate gradients for D_A
    self.backward_D_B()      # calculate graidents for D_B
    self.optimizer_D.step()  # update D_A and D_B's weights

上图(代码):在训练过程中,只训练更新一个模块,不是两个一起训练更新的,当我们在训练生成器时,我们得先做一个限制,训练生成器就只训练生成器,此时判别器是没有工作的。所以你可以看到后面的False 和 True。

在这里插入图片描述

上图(代码):生成器:self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A 用绝对值的方法去算原始的(真实的)A和还原回去的 A 的loss值。self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B 用绝对值的方法去算原始的(真实的)B和还原回去的 B 的loss值。最后算得一个生成器的loss结果。

在这里插入图片描述
上图(代码):判别器:把真的预测成1,假的预测为0。loss_D = (loss_D_real + loss_D_fake) * 0.5把真的和假的的loss算在一起取个加权平均,得到判别器的loss。

  • 5
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qiuzitao

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值