迁移学习论文(四):Generate To Adapt: Aligning Domains using Generative Adversarial Networks论文原理及复现工作

前言

  • 本文属于我迁移学习专栏里的一篇,该专栏用于记录本人研究生阶段相关迁移学习论文的原理阐述以及复现工作。
  • 本专栏的文章主要内容为解释原理,论文具体的翻译及复现代码在文章的github中。

原理阐述

文章介绍

  • 这篇文章于2018年发表在CVPR,作者是Swami Sankaranarayanan,Yogesh Balaji,Carlos D. Castillo,Rama Chellappa。
  • 这篇文章的主要贡献是提出了一个能够直接学习联合特征空间的对抗图像生成的无监督领域自适应方法。该方法与之前方法相比的独特之处在于,同时使用了生成式和判别式两种思想,利用图像生成的对抗过程学习一个源域和目标域特征分布最小化的特征空间。

模型结构

  • 模型是这样的:
    在这里插入图片描述

组件作用总论

  • 上图中的C是分类器,F是特征提取器,G是生成器,D是鉴别器。
  • F用于提取源域和目标域图像特征,C用于源域数据标签预测,G用于生成混淆视听的生成数据,D用于鉴别真实数据和生成数据。
  • 模型总共分为两个支流,支流一是标签预测网络,比较简单,源域数据经过F得到特征,然后输入到C中进行分类,并计算交叉熵分类损失
  • 支流二是生成对抗网络,G用于生成数据混淆D的鉴别任务G根据F提取的源域数据(注意不是目标域数据)特征进行转置卷积得到源域生成数据,它们都不是真实数据,所以标签为0,G计算这些源域生成数据投入到D后的域分类损失和标签分类损失;D用于鉴别,它计算三个损失,首先是源域真实数据直接投入到D,计算域分类损失和标签分类损失,然后是G提取到的目标域生成数据投入到D计算的域分类损失。最后是G提取到的源域生成数据投入到D计算的域分类损失

鉴别器D

  • 作用:
    判定源域生成数据为假——源域生成数据域分类损失(假,域标签为0)
    判定源域真实数据为真,分类误差最小化——源域真实数据域分类损失(真,域标签为1)+源域真实数据标签分类损失
    判定目标域生成数据为假——目标域生成数据域分类损失(假,域标签为0)
  • 其中,
    源域真实数据域分类损失+源域生成数据域分类损失(假类,对应标签为0) 对应论文中的:在这里插入图片描述
    源域真实数据标签分类损失对应论文中的:在这里插入图片描述
    目标域生成数据域分类损失对应论文中的:在这里插入图片描述
    max是对的,log里的越大,loss越大。
  • 代码
    ……

生成器G

  • 作用
    让D误以为源域生成数据为真——源域生成数据域分类损失(为了混淆,所以域标签为1)
    源域生成数据分类误差最小化——源域生成数据标签分类损失
  • 其中
    源域生成数据域分类损失(为了混淆,所以域标签为1)和源域生成数据标签分类损失对应论文中的:在这里插入图片描述
    min,则前一部分越大越好(不看负号),后一部分越小越好,正确。
  • 代码
    ……

分类器C

  • 作用
    源域真实数据分类误差最小化——源域真实数据分类损失(上面提到的鉴别器中的分类损失与分类器C的参数是不共享的,也就是说是两个意义相同但参数不同的分类器,鉴别器中有一个小分类器用于自己使用而不是使用C来计算源域真实数据分类损失)
  • 其中
    对应论文中的:
    在这里插入图片描述
  • 代码
    ……

特征提取器F

  • 作用
    源域真实数据分类误差最小化——源域真实数据标签分类损失
    源域生成数据分类误差最小化——源域生成数据标签分类损失
    让D误以为目标域生成数据为真——目标域生成数据域分类损失(为了混淆,所以域标签为1)
  • 其中
    上面凡是涉及到特征提取的步骤在进行反向传播的时候都会修改特征提取器F的参数,只有直接使用源域数据的源域真实数据域分类损失 是与F无关的。
    源域真实数据标签分类损失对应论文中的在这里插入图片描述
    源域生成数据标签分类损失对应论文中的
    在这里插入图片描述
    目标域生成数据域分类损失(为了混淆,所以域标签为1) 对应论文中的在这里插入图片描述
  • 代码
    ……

警告

  • 从上面对特征提取器的介绍可以看出,这篇论文中的各个组件的反向传播是独立的,很多loss都涉及到了F,但是F的反向传播只与上述三个损失有关,其他损失更新参数的时候不会更新F的参数。也就是说我们在写代码设计模型的时候要把上述四个组件分开写,分别进行反向传播。

原理分析

  • 现在我们重新来看它的结构图
    在这里插入图片描述
  • 我们先来分析D的原理,它计算三个损失,分别是L1-源域生成数据域分类损失L2-源域真实数据域分类损失(真,域标签为1)+L3-源域真实数据标签分类损失,和L4-目标域生成数据域分类损失(假,域标签为0)。L2与L3相对L1与L4会指导D的参数向着对真实数据标签分类和域分类越来越准确的方向调整,而对于目标域生成的数据的域分类损失向着负类调整,也就是说能使得D更好的分别源域和目标域生成数据的区别以及预测。
  • 而生成器G则是由源域生成数据域分类损失(为了混淆,所以域标签为1)源域生成数据标签分类损失约束,是为了让生成的数据接近源数据,这与D是相冲突的。
  • C就不用说了,向着源域数据标签分类损失减小的方向调整参数。
  • F是重点,因为它的参数和C的参数一样影响着模型测试精度,同时它又比C重要,因为C的构造比较简单。F是基础,它受L1-源域真实数据标签分类损失、L2-源域生成数据标签分类损失和L3-目标域生成数据域分类损失(为了混淆,所以域标签为1) 约束,其中L1就是调整F参数来减小源域标签分类损失,L2和L3的存在使得F的参数偏向于使得G生成混淆数据,也就是说偏向于使得目标域生成数据和源域生成数据相近,以及偏向于源域生成数据标签预测准确,这与D是冲突的,F的参数调整是偏向于G的。
  • 但是到这里我想说一下D和G存在的意义,真正参与测试精度的是F和C,而F的参数调整依赖于G和C返回的损失,而G返回的损失中的L2和L3依赖于D,D在这里就好比一个正直的老师,它要区分两样正版和盗版(分别来自源域和G),而正版的源域数据是不受任何参数影响的,它就是样本,直接传入D去指导D的参数,教会这个老师怎么识别正版。而初始阶段,同时来自G的盗版和正版是存在差异的,D可以轻易发现,并且调整内部参数使得自己能更容易发现盗版和正版的区别。所谓D中的内部参数,其实就是来自于D内部的域分类器和标签分类器。而不巧的是,这位正直的老师身边的域分类器和标签分类器被G这个大坏蛋利用了,它通过混淆标签1传回来自域分类器和标签分类器的loss,这个loss用于指导大坏蛋G的内部参数,使得它生成的盗版数据(源域生成数据)在D内部的标签分类和域分类上的表现变好了,这让D这位正直的老师很苦恼,为什么盗版和正版数据变的越来越相似了,殊不知,G利用了D内部的域分类器和标签分类器。这样G就根据窃取的信息改变内部参数使得它生成的数据和正版数据(源域数据)在域分类和标签分类上的表现尽可能相似。这就是G和D的恩怨,又由于F参数的改变是偏向G的,使得D更没有指导信息了,源域数据垂直来指导D的标签分类器和域分类器,但是这两个分类器又被G利用了,直接导致D改进G跟着改进,而高一级的监管机构F又是偏向于G的。而我们进一步思考,这些对于测试精度又有什么影响?事实上只要源域和目标域的同类数据经过F之后生成的特征图特征分布越相似对于测试精度的提高就越有帮助,而在F和G中源域生成数据和目标域生成数据其实是同一地位的,D中的域分类器都把它们标记为负类,所以说经过F(src)和F(tar)对于G来说是差不多的,也就意味着参数的调整偏向于使F(src)和F(tar)相近调整,这就是咱们的目标。

总结

  • 我总感觉这个模型有些地方设计的比较繁琐,后续我会再深入研究下这种基于对抗的模型。
  • 代码有待补充。
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CtrlZ1

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>