CycleGAN论文解读及代码实现

paper: https://arxiv.org/pdf/1703.10593.pdf
github: https://github.com/aitorzip/PyTorch-CycleGAN

1 cycleGAN 小结

  • 网络:
    生成器2个:G_A,G_B
    判别器两个: D_A,D_B
  • 损失函数8个
    6个生成器损失函数
    2个判别器损失函数

1.1 数据

  • fake_B
    原始A,经过生成器G_A,生成fake_B
    A->G_A = fake_B
  • rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • fake_A
    B->G_B = fake_A
    原始B,经过生成器G_B,生成fake_A
  • rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B

1.2 损失函数

6个生成器损失,2个判别器损失
1)6个生成器损失:

  • 生成器一致损失 2个
    数据B经过生成器G_A,后生成的B,与原始B距离最小。A同理
    ① B-> G_A->B’ : 使得B 与B’距离最小
    ② A-> F_B->A’ : 使得A 与A’距离最小
  • 生成器损失 2个
    生成器生成的数据,让判别器都判别为真
    ③ MSELoss(D_A(fake_B), True)
    ④ MSELoss(D_B(fake_A), True)
  • 循环一致损失 2个
    ⑤ 原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
    ⑥原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
    B->G_B ->G_A = rec_B

2) 2个判别器损失

  • 判别器损失 2个
    使真实图片为判别为真,假图片判别为假
    ① D_A
    pred_real = D_A(real); pred_fake= D_A(fake)
    MSELoss(pred_real, True)+MSELoss(pred_fake, False)
    ②D_B
    pred_real = D_B(real); pred_fake= D_B(fake)
    MSELoss(pred_real, True)+MSELoss(pred_fake, False)

2 模型架构

  • 两个生成网络:
    G: X——> Y ,输入X生成Y
    F: Y——> X :输入Y生成X
  • 两个判别网络:
    D_A: 用于区分真实A和 F(B)生成的假A.
    D_B:用于区分真实B和 G(A)生成的假B.

在这里插入图片描述

3 损失函数

3.1 Adversarial Loss

对抗损失:

  • 对于生成器 G: X——> Y
    生成器G_X: 最小化以下目标函数
    对于判别器D_Y:最大化以下目标函数
    在这里插入图片描述
  • 对于生成器 F: Y——> X,损失函数同上
    生成器F_Y,使判别器D_X判断为真
    对于判别器D_X:是真实X判断为真,F_Y生成的X,判断为假。
    L G A N ( F , D X , Y , X ) L_{GAN}(F,D_X,Y,X) LGAN(F,DX,Y,X)

3.2 Cycle Consistency Loss

循环一致损失,即 X 经过生成器G_x后 得到Y,Y再过F_Y生成X,使得前后生成的X距离最小。
1) 前向一致损失
即从x 经过网络后还原为x的过程
X − > G ( x ) − > F ( G ( x ) ) = X X -> G(x) -> F(G(x)) =X X>G(x)>F(G(x))=X

2)反向一致损失
即y从经过网络后还原为y的过程
Y − > F ( y ) − > G ( F ( y ) ) = Y Y -> F(y) -> G(F(y)) =Y Y>F(y)>G(F(y))=Y

在这里插入图片描述

3.3 Full Objective

在这里插入图片描述

4 代码实现

4.1网络结构

  • 1 生成器A :
    netG_A:可以选用resnet,或者unet网络
    输入数据A,生成数据B
self.netG_A = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
  • 2 生成器B:
    netG_B: 与netG_A网络一样
    输入数据B,生成数据A
 self.netG_B = networks.define_G(opt.output_nc, opt.input_nc, opt.ngf, opt.netG, opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
  • 3 判别器A
self.netD_A = networks.define_D(opt.output_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)

  • 4 判别器B
 self.netD_B = networks.define_D(opt.input_nc, opt.ndf, opt.netD,
                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)


5 损失

5.1 前向传播数据

  • self.fake_B
    原始A,经过生成器G_A,生成fake_B
    A->G_A = fake_B
  • self.rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • self.fake_A
    B->G_B = fake_A
    原始B,经过生成器G_B,生成fake_A
  • self.rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
self.fake_B = self.netG_A(self.real_A)  # G_A(A)
self.rec_A = self.netG_B(self.fake_B)   # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B)  # G_B(B)
self.rec_B = self.netG_A(self.fake_A)   # G_A(G_B(B))

5.2 生成器一致损失

数据B经过生成器G_A,后生成的B,与原始B距离最小。
B-> G_A->B’ : 使得B 与B’距离最小
A-> F_B->A’ : 使得A 与A’距离最小

self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A= self.L1Loss(self.idt_A, self.real_B)

self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.L1Loss(self.idt_B, self.real_A)

5.2 生成器损失

生成器生成的数据,让判别器都判别为真

(备注:判别器输出不是一个值,而是一个矩阵,需要使判别器输出矩阵每一个值都接近1)

# GAN loss D_A(G_A(A))
self.loss_G_A = self.MSELoss(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterioMSELossGAN(self.netD_B(self.fake_A), True)

5.3 循环一致损失

使得重构的A与原始A距离最近,使用L1Loss

  • self.rec_A
    原始A,经过生成器G_A,生成fake_B,再经过生成器G_B,生成重构数据rec_A
    A->G_A ->G_B = rec_A
  • self.rec_B
    B->G_B ->G_A = rec_B
    原始B,经过生成器G_B,生成fake_A,再经过生成器G_A,生成重构数据rec_B
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.L1Loss(self.rec_A, self.real_A) 
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.L1Loss(self.rec_B, self.real_B) 

5.4 生成器总loss

上面6个生成器损失求和即为总的生成损失函数

self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B

5.5 判别器损失

判别器:使真实图片为判别为真,假图片判别为假

pred_real = netD(real)
loss_D_real = self.MSELoss(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.MSELoss(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
  • 0
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值