Cycle-GAN阅读笔记
Summary
Cycle-GAN是图片域与域之间的生成对抗网络,它能够在不同风格的图片中相互转换。相比较于Pix2pix-GAN,其最大的优点是对数据集的要求低,不需要成对的数据集便可以实现不同图片域的转换,这大大降低了生成对抗网络对于数据集的要求。
Contribution
- 传统的训练都是利用图片对进行训练,存在的问题就是图片对不好收集,比如从真实图片转换到MoNet风格的图片,论文提出了circle结构,实现了图片域<->图片域的转换
Related Work
- 生成对抗网络
- 图像到图像翻译
- 未配对的图像到图像翻译
- 周期一致性
- 神经风格转换
Approach and Model Architecture
Cycle-GAN的架构
其原理是对于两个图片域X, Y,训练G, F使得G(X)=Y且F(Y)=X,即完成了两个图片域的相互转换,训练损失即使得
D
X
D_X
DX区分不出F(y)和X,
D
Y
D_Y
DY区分不出G(x)和Y (这里使用了least-squares loss)。另外为了确保单一映射,作者在对抗损失之外还提出了循环一致性损失,即F(G(x))与x的一范式最小,G(F(y))和y也一样。
将这两者的Loss线性累加,最终给出的损失函数为
其中
作者采用的网络结构,采用了Johnson等人的生成网络的体系结构,对于鉴别器网络,作者使用了70×70个PatchGAN,这也是pix2pix使用的鉴别器。
训练方面,作者使用了两个小trick,第一个是对于对抗损失,作者使用了least-squares loss;第二个是,为了使得鉴别器不至于能力太超前,作者设置了图像缓冲区,即此刻的鉴别器识别的的是生成器生成的向前第n张图(作者为此设置了50图像的缓冲区),这样能够保证生成器和鉴别器能力相适应。
Comparison against baselines
为了评估Circle-GAN的生成效果,作者遵循了Pix2pix-GAN的评估方法,在不同任务不同指标情况下对生成图片进行评估,以下是评估结果
当然应该着重比较Pix2pix-GAN与Circle-GAN,因为他们的生成图片具有相同的意义,当然由于Pix2pix-GAN属于图像对的学习,所以理论上Circle-GAN性能更差一点也情有可原。
另一些数值指标
作者还研究了Circle-GAN中使用的Loss对最终生成图片评分的影响,结果发现对抗损失和循环一致性损失都起着很关键的作用,缺一不可。另外如果只训练一个方向,例如只训练
L
(
G
,
D
Y
,
X
,
Y
)
\mathcal{L}(G,D_Y,X,Y)
L(G,DY,X,Y),会常引起训练不稳定性并导致模式崩溃,特别是对于没有训练的那个方向。
为了进一步直观地说明Loss的缺失对生成图片的影响,作者给出另一个实验结果,图中可知缺失某些Loss对最后的生成缺失产生了一些影响。
作者还测试了一下图像的重构,即 F ( G ( x ) ) F(G(x)) F(G(x)),并将其与原来的图片x进行比较,结果发现重构的图片大致能够恢复出输入图片的细节,表明域与域之间的单一映射较为准确。
接着作者进行了Pix2pix-GAN一样的实验,结论是Circle-GAN能产生接近Pix2pix的效果,这其实是挺难得的了,毕竟Circle-GAN是没有配对监督的方法。
我们注意到培训数据的翻译通常比测试数据的翻译更具吸引力,更多的实验发布在https://junyanz.github.io/CycleGAN/中,里面的实验都很有趣。
在某些情况下,给GAN加上额外的损失十分有必要,比如在关于艺术图像转变的示例中,可以增加一个L1范式作为损失,确保较大程度保留颜色信息。
更多艺术风格的转换图
下一个实验是数据增强:增强景深
艺术图像转换+1
不同方法在图像转换中的比较,Circle-GAN取得了领先的水平。
虽然Circle-GAN能在大多数图像转换的任务中表现出良好的效果,但是在一些特定的任务中,它也存在着很多不足。在左边猫狗转换中,Circle-GAN只能进行微小的改变,而在右边马到斑马的转化中,由于数据集的问题,人也染上了斑马的条纹。
我们还观察到配对训练数据可达到的结果与未配对方法可达到的结果之间仍存在差距。在某些情况下,这种差距可能很难消除,甚至无法消除:例如,我们的方法有时会在photos→labels任务的输出中置换用于树和建筑物的标签。要解决这种歧义,可能需要某种形式的弱语义监督。集成弱或半监督的数据可能会导致翻译器的功能更加强大,但其成本仅为全监督系统的注释成本的一小部分。
最后一些实验图
Code
Cycle-GAN的代码主要参考了https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix#custom-model-and-dataset
给出最关键训练的部分代码如下,做个了总结
- 该Cycle-GAN的训练部分
- 训练鉴别器
- 对抗损失 (
M
S
E
?
MSE?
MSE?)
- 对于 X X X域到 Y Y Y域,训练鉴别器 D Y D_Y DY,使得 D Y D_Y DY能判别 G X ( x ) G_X(x) GX(x)为假, y y y为真
- 对于 Y Y Y域到 X X X域,训练鉴别器 D X D_X DX,使得 D X D_X DX能判别 G Y ( y ) G_Y(y) GY(y)为假, x x x为真
- 对抗损失 (
M
S
E
?
MSE?
MSE?)
- 训练生成器
- 对抗损失 (
M
S
E
?
MSE?
MSE?)
- 对于 X X X域到 Y Y Y域,训练生成器 G X G_X GX,使得 D Y D_Y DY能判别 G X ( x ) G_X(x) GX(x)为真
- 对于 Y Y Y域到 X X X域,训练生成器 G Y G_Y GY,使得 D X D_X DX能判别 G Y ( y ) G_Y(y) GY(y)为真
- 循环一致性损失 (
L
1
L1
L1)
- 对于 X X X域到 X X X域,训练生成器 G X G_X GX、 G Y G_Y GY,使得 G Y ( G X ( x ) ) G_Y(G_X(x)) GY(GX(x))接近于 x x x
- 对于 Y Y Y域到 Y Y Y域,训练生成器 G Y G_Y GY、 G X G_X GX,使得 G X ( G Y ( y ) ) G_X(G_Y(y)) GX(GY(y))接近于 y y y
- 保留原图像信息损失 (
L
1
L1
L1) (可选)
- 对于 X X X域到 Y Y Y域,训练生成器 G X G_X GX,使得 G X ( x ) G_X(x) GX(x)接近于 x x x
- 对于 Y Y Y域到 X X X域,训练生成器 G Y G_Y GY,使得 G Y ( y ) G_Y(y) GY(y)接近于 y y y
- 对抗损失 (
M
S
E
?
MSE?
MSE?)
- 训练鉴别器
if self.isTrain:
if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
assert(opt.input_nc == opt.output_nc)
self.fake_A_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
self.fake_B_pool = ImagePool(opt.pool_size) # create image buffer to store previously generated images
# define loss functions
self.criterionGAN = networks.GANLoss(opt.gan_mode).to(self.device) # define GAN loss.
self.criterionCycle = torch.nn.L1Loss()
self.criterionIdt = torch.nn.L1Loss()
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
self.optimizer_G = torch.optim.Adam(itertools.chain(self.netG_A.parameters(), self.netG_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizer_D = torch.optim.Adam(itertools.chain(self.netD_A.parameters(), self.netD_B.parameters()), lr=opt.lr, betas=(opt.beta1, 0.999))
self.optimizers.append(self.optimizer_G)
self.optimizers.append(self.optimizer_D)
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): include the data itself and its metadata information.
The option 'direction' can be used to swap domain A and domain B.
"""
AtoB = self.opt.direction == 'AtoB'
self.real_A = input['A' if AtoB else 'B'].to(self.device)
self.real_B = input['B' if AtoB else 'A'].to(self.device)
self.image_paths = input['A_paths' if AtoB else 'B_paths']
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
self.fake_B = self.netG_A(self.real_A) # G_A(A)
self.rec_A = self.netG_B(self.fake_B) # G_B(G_A(A))
self.fake_A = self.netG_B(self.real_B) # G_B(B)
self.rec_B = self.netG_A(self.fake_A) # G_A(G_B(B))
def backward_D_basic(self, netD, real, fake):
"""Calculate GAN loss for the discriminator
Parameters:
netD (network) -- the discriminator D
real (tensor array) -- real images
fake (tensor array) -- images generated by a generator
Return the discriminator loss.
We also call loss_D.backward() to calculate the gradients.
"""
# Real
pred_real = netD(real)
loss_D_real = self.criterionGAN(pred_real, True)
# Fake
pred_fake = netD(fake.detach())
loss_D_fake = self.criterionGAN(pred_fake, False)
# Combined loss and calculate gradients
loss_D = (loss_D_real + loss_D_fake) * 0.5
loss_D.backward()
return loss_D
def backward_D_A(self):
"""Calculate GAN loss for discriminator D_A"""
fake_B = self.fake_B_pool.query(self.fake_B)
self.loss_D_A = self.backward_D_basic(self.netD_A, self.real_B, fake_B)
def backward_D_B(self):
"""Calculate GAN loss for discriminator D_B"""
fake_A = self.fake_A_pool.query(self.fake_A)
self.loss_D_B = self.backward_D_basic(self.netD_B, self.real_A, fake_A)
def backward_G(self):
"""Calculate the loss for generators G_A and G_B"""
lambda_idt = self.opt.lambda_identity
lambda_A = self.opt.lambda_A
lambda_B = self.opt.lambda_B
# Identity loss
if lambda_idt > 0:
# G_A should be identity if real_B is fed: ||G_A(B) - B||
self.idt_A = self.netG_A(self.real_B)
self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
# G_B should be identity if real_A is fed: ||G_B(A) - A||
self.idt_B = self.netG_B(self.real_A)
self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
else:
self.loss_idt_A = 0
self.loss_idt_B = 0
# GAN loss D_A(G_A(A))
self.loss_G_A = self.criterionGAN(self.netD_A(self.fake_B), True)
# GAN loss D_B(G_B(B))
self.loss_G_B = self.criterionGAN(self.netD_B(self.fake_A), True)
# Forward cycle loss || G_B(G_A(A)) - A||
self.loss_cycle_A = self.criterionCycle(self.rec_A, self.real_A) * lambda_A
# Backward cycle loss || G_A(G_B(B)) - B||
self.loss_cycle_B = self.criterionCycle(self.rec_B, self.real_B) * lambda_B
# combined loss and calculate gradients
self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
self.loss_G.backward()
def optimize_parameters(self):
"""Calculate losses, gradients, and update network weights; called in every training iteration"""
# forward
self.forward() # compute fake images and reconstruction images.
# G_A and G_B
self.set_requires_grad([self.netD_A, self.netD_B], False) # Ds require no gradients when optimizing Gs
self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
self.backward_G() # calculate gradients for G_A and G_B
self.optimizer_G.step() # update G_A and G_B's weights
# D_A and D_B
self.set_requires_grad([self.netD_A, self.netD_B], True)
self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
self.backward_D_A() # calculate gradients for D_A
self.backward_D_B() # calculate graidents for D_B
self.optimizer_D.step() # update D_A and D_B's weights