单侧无监督域适应的几何一致生成对抗网络
1.摘要:
无监督域映射旨在学习一个函数翻译域X图像到域Y图像,在配对样本缺少的情况下。在没有配对数据情况下,发现最优的是一个病态的问题,因此获得合理的解需要适合的约束。尽管一些著名的(prominent)约束,例如循环一致性(cycle consistency), 距离保留(distance preservation)成功地约束解空间,但是他们忽视了图像的特殊属性——简单的几何变换不会改变图像的语义结构。基于这些特殊的属性,作者提出一个几何一致性生成对抗网络(geometry-consistent generative adversarial networks -GcGAN), 进行单侧无监督域适应。GcGAN把原始图像以及对应几何变换图像作为模型的输入,并在新域中生成两个图像,并加上相应的几何一致性约束。几何一致性约束减少了可能解的空间并保证了在搜索空间中正确解。与基线模型GAN,以及最新的方法CycleGAN, DistanceGAN进行定性,定量的对比证明我们的方法的有效性。
该图片可视化CycleGAN, DistanceGAN , GcGAN之间的不同。
文章中,作者采用了两种常用的几何变换:顺时针旋转90度,垂直翻转。
2.结合代码讲论文:
论文中实验最优结果所使用的约束:对抗约束,几何一致性约束,循环一致性约束(GcGAN-rot + Cycle)
对抗约束:
循环一致性约束:
几何一致性约束:
距离约束 Distance constraint:
代码分析:
对代码链接中gc_cycle_gan_model.py文件中的代码进行分析:
模型的主要优化过程:输入数据,优化生成器,优化目标域判别器,优化源域判别器。
def optimize_parameters(self):
# forward
self.forward()
# G_AB
self.optimizer_G.zero_grad()
self.backward_G()
self.optimizer_G.step()
# D_B and D_gc_B
self.optimizer_D_B.zero_grad()
self.backward_D_B()
self.optimizer_D_B.step()
self.optimizer_D_A.zero_grad()
self.backward_D_A()
self.optimizer_D_A.step()
输入数据
def forward(self):
input_A = self.input_A.clone()
input_B = self.input_B.clone()
self.real_A = self.input_A
self.real_B = self.input_B
size = self.opt.fineSize
if self.opt.geometry == 'rot':
self.real_gc_A = self.rot90(input_A, 0)
self.real_gc_B = self.rot90(input_B, 0)
elif self.opt.geometry == 'vf':
inv_idx = torch.arange(size-1, -1, -1).long().cuda()
self.real_gc_A = torch.index_select(input_A, 2, inv_idx)
self.real_gc_B = torch.index_select(input_B, 2, inv_idx)
else:
raise ValueError("Geometry transformation function [%s] not recognized." % self.opt.geometry)
输入源域图像self.input_A,目标域图像self.input_B,根据事先设置的几何变换参数self.opt.gemetry进行顺时针旋转90度变换(在代码中体现,self.real_gc_A = self.rot90(input_A, 0),如果是逆时针旋转90度,则self.rot()函数第二个参数设置为1)。
优化生成器:
def backward_G(self):
# adversariasl loss
fake_B = self.netG_AB.forward(self.real_A)
pred_fake = self.netD_B.forward(fake_B)
loss_G_AB = self.criterionGAN(pred_fake, True)*self.opt.lambda_G
fake_gc_B = self.netG_AB.forward(self.real_gc_A)
pred_fake = self.netD_gc_B.forward(fake_gc_B)
loss_G_gc_AB = self.criterionGAN(pred_fake, True)*self.opt.lambda_G
fake_A = self.netG_BA.forward(self.real_B)
pred_fake = self.netD_A.forward(fake_A)
loss_G_AB += self.criterionGAN(pred_fake, True)*self.opt.lambda_G
fake_gc_A = self.netG_BA.forward(self.real_gc_B)
pred_fake = self.netD_gc_A.forward(fake_gc_A)
loss_G_gc_AB += self.criterionGAN(pred_fake, True)*self.opt.lambda_G
if self.opt.geometry == 'rot':
loss_gc = self.get_gc_rot_loss(fake_B, fake_gc_B, 0)
loss_gc += self.get_gc_rot_loss(fake_A, fake_gc_A, 0)
elif self.opt.geometry == 'vf':
loss_gc = self.get_gc_vf_loss(fake_B, fake_gc_B)
loss_gc += self.get_gc_vf_loss(fake_A, fake_gc_A)
if self.opt.identity > 0:
# G_AB should be identity if real_B is fed.
idt_A = self.netG_AB(self.real_B)
loss_idt = self.criterionIdt(idt_A, self.real_B) * self.opt.lambda_AB * self.opt.identity
idt_gc_A = self.netG_AB(self.real_gc_B)
loss_idt_gc = self.criterionIdt(idt_gc_A, self.real_gc_B) * self.opt.lambda_AB * self.opt.identity
idt_B = self.netG_BA(self.real_A)
loss_idt += self.criterionIdt(idt_B, self.real_A) * self.opt.lambda_AB * self.opt.identity
idt_gc_B = self.netG_BA(self.real_gc_A)
loss_idt_gc += self.criterionIdt(idt_gc_B, self.real_gc_A) * self.opt.lambda_AB * self.opt.identity
self.idt_A = idt_A.data
self.idt_gc_A = idt_gc_A.data
self.loss_idt = loss_idt.item()
self.loss_idt_gc = loss_idt_gc.item()
else:
loss_idt = 0
loss_idt_gc = 0
self.loss_idt = 0
self.loss_idt_gc = 0
rec_A = self.netG_BA(fake_B)
loss_cycle_A = self.criterionCycle(rec_A, self.real_A) * self.opt.lambda_AB
rec_gc_A = self.netG_BA(fake_gc_B)
loss_cycle_A += self.criterionCycle(rec_gc_A, self.real_gc_A) * self.opt.lambda_AB
rec_B = self.netG_AB(fake_A)
loss_cycle_B = self.criterionCycle(rec_B, self.real_B) * self.opt.lambda_AB
rec_gc_B = self.netG_BA(fake_gc_A)
loss_cycle_B += self.criterionCycle(rec_gc_B, self.real_gc_B) * self.opt.lambda_AB
loss_G = loss_G_AB + loss_G_gc_AB + loss_gc + loss_idt + loss_idt_gc + loss_cycle_A + loss_cycle_B
loss_G.backward()
self.fake_B = fake_B.data
self.fake_gc_B = fake_gc_B.data
self.fake_A = fake_A.data
self.fake_gc_A = fake_gc_A.data
self.loss_G_AB = loss_G_AB.item()
self.loss_G_gc_AB= loss_G_gc_AB.item()
self.loss_gc = loss_gc.item()
代码中self.real_A,fake_B, self.real_gc_A, fake_gc_B分别对应图中,
self.netG_AB 对应
几何一致性损失约束,体现在:
if self.opt.geometry == 'rot':
loss_gc = self.get_gc_rot_loss(fake_B, fake_gc_B, 0)
loss_gc += self.get_gc_rot_loss(fake_A, fake_gc_A, 0)
elif self.opt.geometry == 'vf':
loss_gc = self.get_gc_vf_loss(fake_B, fake_gc_B)
loss_gc += self.get_gc_vf_loss(fake_A, fake_gc_A)
def get_gc_rot_loss(self, AB, AB_gc, direction):
loss_gc = 0.0
if direction == 0:
AB_gt = self.rot90(AB_gc.clone().detach(), 1)
loss_gc = self.criterionGc(AB, AB_gt)
AB_gc_gt = self.rot90(AB.clone().detach(), 0)
loss_gc += self.criterionGc(AB_gc, AB_gc_gt)
else:
AB_gt = self.rot90(AB_gc.clone().detach(), 0)
loss_gc = self.criterionGc(AB, AB_gt)
AB_gc_gt = self.rot90(AB.clone().detach(), 1)
loss_gc += self.criterionGc(AB_gc, AB_gc_gt)
loss_gc = loss_gc*self.opt.lambda_AB*self.opt.lambda_gc
#loss_gc = loss_gc*self.opt.lambda_AB
return loss_gc
与进行L1范数计算流程:
在get_gc_rot_loss()损失中:
,分别对应与AB,AB_gc, AB_gt,AB_gc_gt。
作者使用几何一致性损失 + 旋转几何变换 + 循环一致性损失在城市景观数据集上去的最好的结果。要搞清楚作者在合成图像重建回原始图像过程中是否使用几何一致性损失?
根据代码以及论文的题目中“单侧无监督域适应”,所以作者只在翻译到或者翻译到过程中使用了几何一致性损失。
代码链接:
【1】https://github.com/hufu6371/GcGAN
作者怎么评估翻译图像的质量?
对于图像标签图翻译到图像的过程,作者认为高质量的翻译图像应该产生定性的分割结果,就像真实图像的分割结果一样。
因此作者使用pixel accuracy, class accuracy, mean IOU 评估翻译图像的分割结果,使用pix2pix 提供的预训练模型FCN-8s分割合成图像。