分布对齐 目标函数 迁移学习_伪文艺程序员的“迁移学习”啃读(G)

本文详细介绍了深度迁移学习的多种方法,包括DeepCoral、DRCN、CoGAN、RTN、JAN、ADDA、CDAN、MCD和GTA等,重点解析了这些算法的原理和应用,如DeepCoral的Coral损失、CoGAN的权重共享、GTA的AC-GAN框架等。
摘要由CSDN通过智能技术生成
00faae7cfe9cf720a0ec802b5d6a2f89.gif点击“网络人工智能园地”关注我们~ 49475e2853e1048b4ba580ea5dc196a1.png

0fd99b0e55fa4b8bad88c6d84333ef7b.png

作者:新春

————————

计算机软件新技术国家重点实验室伪文艺程序员

既可提刀立码,行遍天下

又可调参炼丹,卧于隆中

e9b069dceb300cb69b6b531c79abd7c4.png

注明:本文首发在NAIE《网络人工智能园地》公众号,如有转载,请注明出处。

深度迁移方法主要是指基于深度网络端到端进行训练优化的迁移算法,主要包括基于统计、基于对抗和基于重构等方法的对齐。具体包括以下算法:

  • DDC (CoRR2014)

  • DAN (ICML2015)

  • DaNN (ICML2015)

  • DeepCoral (ECCV2016)

  • DRCN (ECCV2016)

  • CoGAN (NeurIPS2016)

  • RTN (NeurIPS2016)

  • JAN (ICML2017)

  • ADDA (CVPR2017)

  • CDAN (NeurlPS2018)

  • MCD (CVPR2018)

  • GTA (CVPR2018)

浅层迁移算法传送门:

李新春:Understanding Transfer Learning (A)zhuanlan.zhihu.com2a1bd579cb6068335a289a40245a8600.png

深度迁移算法DDC、DAN、DaNN传送门:

李新春:Understanding Transfer Learning (B)zhuanlan.zhihu.com788a7a7cfa58d7e15115f95ea7abdcfd.png

接下来的部分就针对以上列举的深度算法进行介绍。

01

DeepCoral (ECCV2016)

ed9742ccc40cd14a70da8c26967ab373.gif
eecccc5f8940d475bea329aa2255da7d.png

这篇文章基于浅层迁移算法CORAL进行应用到深度网络,CORAL主要是对齐数据的二阶协方差矩阵。因此,DeepCoral的主要思想是引入了下面的Coral损失:

3b892f3e-b957-eb11-8da9-e4434bdf6706.svg

其中 3e892f3e-b957-eb11-8da9-e4434bdf6706.svg 是源域和目标域在要对齐层的特征的协方差矩阵:

3f892f3e-b957-eb11-8da9-e4434bdf6706.svg

02

DRCN (ECCV2016)

ed9742ccc40cd14a70da8c26967ab373.gif
e09e783e9670e841d7df185b3c72af53.png

DRCN是基于重构的迁移方法,类似于多任务学习的方法,源域做的是分类任务,目标域由于没有标签,选择做重构任务(基于AutoEncoder框架),两个任务共享特征提取器,得到的优化过程如下:

46892f3e-b957-eb11-8da9-e4434bdf6706.svg

f87837a1b134df83ac1e6cf94dae0c87.png
03

CoGAN (NeurIPS2016)

ed9742ccc40cd14a70da8c26967ab373.gif
0b05be5fe65da97195b7a721cba234a0.png

CoGAN采用生成式模型进行迁移,理解这篇文章的做法具体包括两个步骤。第一步是弄清楚Coupled GAN的操作方法,示意图如下:

fd9fd92925016a2c0cdb8e5800df870c.png

其中,CoGAN包括两个并行的GAN,即源域和目标域上分别有一个GAN网络来生成数据,包括Generator和Discriminator,为了共享源域和目标域的信息,两个GAN网络共享了Generator和Discriminator部分层的权重(Weight Sharing)。

第二步是弄懂这个网络如何进行迁移呢?在应用到迁移任务的时候,源域的GAN需要加入额外的分类损失,在Source Discriminator后面接一层分类器Classifier,一起训练,在预测Target样本的时候,通过Target Discriminator和Classifier进行分类。由于Target Discriminator前几层是和源域解耦开来的,可以提取域相关的特征,后几层是和源域共享的,是高层语义的提取和分类工作,因此可以工作地很好。

这个工作的核心在于两个GAN的共享层和私有层,因此文章对共享什么层,共享多少层进行了评测:

3fb597969e01a5e6969af7438da3babb.png
04

RTN (NeurIPS2016)

ed9742ccc40cd14a70da8c26967ab373.gif
a7ee82dfcbe70cd4b9ff34bda1be61b1.png

RTN是Residual Transfer Networks的缩写,也是基于统计的对齐方法,它主要解决的问题是同时考虑Classifier在源域和目标域上存在差异(之前的工作大多都是直接将源域的Classifier放到目标域分类)。这篇工作主要包括三个部分:基于Bottleneck联合分布和MMD的特征对齐、基于残差分类器的分类器对齐、目标域分类器的信息熵正则化项。

由于Backbone的参数量比较大且输出维度比较高(比如AlexNet的4096维和ResNet50的2048维),加入了一层Bottleneck层,记作 67892f3e-b957-eb11-8da9-e4434bdf6706.svg ,然后到最后一层分类器 6c892f3e-b957-eb11-8da9-e4434bdf6706.svg ,为了对齐数据特征,可以采用类似DAN的方法采用多层MK-MMD的损失,但是这里采用了 67892f3e-b957-eb11-8da9-e4434bdf6706.svg 和 6c892f3e-b957-eb11-8da9-e4434bdf6706.svg 张量内积进行使用MMD对齐:

76892f3e-b957-eb11-8da9-e4434bdf6706.svg

7a892f3e-b957-eb11-8da9-e4434bdf6706.svg

其次,为了保持源域和目标域分类器层的差异,引入了残差分类器,即目标域样本直接根据 80892f3e-b957-eb11-8da9-e4434bdf6706.svg 进行分类,而源域则通过 83892f3e-b957-eb11-8da9-e4434bdf6706.svg 进行分类,在实现过程就当做源域的分类器多了几层。

最后为了因为源域目标域不一致带来的分类性能的下降,在目标域的分类过程引入了信息熵作为监督信息,即: 88892f3e-b957-eb11-8da9-e4434bdf6706.svg

最后的网络结构如下,包含了上面的几个模块:

454ac7fd2842212005575a77988df87d.png
05

JAN (ICML2017)

ed9742ccc40cd14a70da8c26967ab373.gif
c4797976147fae1f56abc9e79017cea4.png

JAN是Joint Adaptation Network,这里的Joint指的并不是数据和标签的联合分布,而是指的多层隐含层的特征的联合分布。DAN使用单独的MK-MMD对多层特征进行分别对齐,这里JAN将层与层之间的关系建模,并以MMD的方式进行对齐。

这里需要一些Kernel Embedding方面的知识,可以参考下面的传送门:

李新春:Deep Kernel Density Estimationzhuanlan.zhihu.com6372605702c2fd424284d3178e8fe7c4.png李新春:Kernel Distribution Embeddingzhuanlan.zhihu.com9f5b846e6527379e52c6dfb59d15eed4.png

再稍微回忆一下Kernel Embedding的东西:

93892f3e-b957-eb11-8da9-e4434bdf6706.svg  是定义在 94892f3e-b957-eb11-8da9-e4434bdf6706.svg 上的一个希尔伯特函数空间(Hilbert Space Function Spase), 93892f3e-b957-eb11-8da9-e4434bdf6706.svg 的再生核是 99892f3e-b957-eb11-8da9-e4434bdf6706.svg ,需要满足两个条件:1)对于任意的 9c892f3e-b957-eb11-8da9-e4434bdf6706.svg 来说 9f892f3e-b957-eb11-8da9-e4434bdf6706.svg ;2) 对于任意的 a1892f3e-b957-eb11-8da9-e4434bdf6706.svg 以及 a6892f3e-b957-eb11-8da9-e4434bdf6706.svg ,有再生性质 ac892f3e-b957-eb11-8da9-e4434bdf6706.svg .

其中最重要的一个性质是(这个式子将内积、希尔伯特空间、核函数三个概念完美地融合在了一起):

af892f3e-b957-eb11-8da9-e4434bdf6706.svg

单一变量 b0892f3e-b957-eb11-8da9-e4434bdf6706.svg 的Kernel Distribution Embedding指的是:

b1892f3e-b957-eb11-8da9-e4434bdf6706.svg

其有限样本估计的经验公式为:

b4892f3e-b957-eb11-8da9-e4434bdf6706.svg

当变量扩展到 b6892f3e-b957-eb11-8da9-e4434bdf6706.svg 时,且每个变量都是一个高维的张量,根据张量积空间(Tensor Product Space)的性质有这 b9892f3e-b957-eb11-8da9-e4434bdf6706.svg 对应的Kernel Distribution Embedding为:

ba892f3e-b957-eb11-8da9-e4434bdf6706.svg

其对应的经验统计式子为:

bc892f3e-b957-eb11-8da9-e4434bdf6706.svg

然后就可以计算MMD损失,回顾一下单变量的MMD损失为:

c0892f3e-b957-eb11-8da9-e4434bdf6706.svg

其经验统计的结果为:

c2892f3e-b957-eb11-8da9-e4434bdf6706.svg

扩展到多变量的联合分布即可以得到JMMD:

c3892f3e-b957-eb11-8da9-e4434bdf6706.svg

其对应的经验公式为:

c6892f3e-b957-eb11-8da9-e4434bdf6706.svg

然后是Adversarial JMMD的想法,MMD对 c9892f3e-b957-eb11-8da9-e4434bdf6706.svg 要求比较高,因为上面的公式涉及到了 ca892f3e-b957-eb11-8da9-e4434bdf6706.svg ,而传统的做法是基于一个Universal Kernel来做,认为Universal Kernel已经足够好。本文则提出了类似于DaNN对抗的思想来丰富 cb892f3e-b957-eb11-8da9-e4434bdf6706.svg 的表示能力,主要思想是对特征引入一层变化层(参数记作 ce892f3e-b957-eb11-8da9-e4434bdf6706.svg )进行对抗。JAN和JAN-A的示意图如下:

e37900bc4866b306423df2a93da1cb5f.png
06

ADDA (CVPR2017)

ed9742ccc40cd14a70da8c26967ab373.gif
92818952a7047965a38f49627870da9f.png

ADDA这篇工作首先整合了一个DA的框架,具体包括:

  • 特征提取器是生成式还是判别式:比如DRCN、CoGAN都是基于生成模型建模的,属于生成式,而DaNN则是基于判别式的;

  • 特征提取器在源域和目标域式共享还是解耦:比如DaNN、DRCN都是共享的,CoGAN共享了部分层,归为了不共享的类;

  • 使用的对抗损失:DaNN使用梯度反转层,即Minimax优化,Domain Confusion使用了confusion的损失(优化Generator使得Domain Classifier将目标域预测的概率分为均匀分布),CoGAN使用了真/假的Discriminator损失,也属于GAN Loss。

53d69af50dfb529cccf854e37a225177.png
731055ad56df339086e59ae02cb26c21.png

这里解释一下三种Loss。第一种是基于Minimax的损失,使用梯度反转层:

da892f3e-b957-eb11-8da9-e4434bdf6706.svg

第二种使用GAN Loss,则是优化Generator (G)使得在目标域上提取的特征,让Domain Classifier (D)预测的结果尽可能接近1(源域样本的域标签):

db892f3e-b957-eb11-8da9-e4434bdf6706.svg

第三种使用Confusion Loss,则是优化Generator (G)使得源域、目标域样本的特征通过Domain Classifier分类之后是均匀分布:

dd892f3e-b957-eb11-8da9-e4434bdf6706.svg

ADDA使用的则是:判别式Generator、源域和目标域使用不一样的Generator、GAN损失。

1b53d22da3da845a7e606723c1aed54c.png

ADDA的训练流程如上:第一步,在源域训练特征提取器(Source CNN)和分类器(Classifier);第二步,将源域特征提取器拷贝一份到目标域当作Target CNN,然后固定住Source CNN,单向训练Target CNN使得Target CNN提取的特征向源域对齐,具体训练过程是和使用GAN Loss的Discriminator进行对抗;第三步,将目标域CNN和源域的Classifier放一起,进行预测。

07

CDAN (NeurlPS2018)

ed9742ccc40cd14a70da8c26967ab373.gif
5aeb86c50f9217bea6b284b7eaa69f22.png

CDAN主要是将数据的多模态(multimodal)性质考虑进去,指的是数据对齐的时候不同类别的源域和目标域可能会对齐到一起,此时虽然基于对抗或者统计的方法已经无法区分源域和目标域特征,但是实际上却是源域类别A的数据和目标域B的数据对齐,导致最后分类性能变差。

记特征提取器提取的特征为 e5892f3e-b957-eb11-8da9-e4434bdf6706.svg ,分类器得到的输出(logits)为 e6892f3e-b957-eb11-8da9-e4434bdf6706.svg 。CDAN主要是将特征和输出(logits)联合在一起进行对齐,使用Multilinear Map(RTN和JAN里面使用的张量积):

e9892f3e-b957-eb11-8da9-e4434bdf6706.svg

一般而言, ea892f3e-b957-eb11-8da9-e4434bdf6706.svg 是两个向量形式的张量,因此 eb892f3e-b957-eb11-8da9-e4434bdf6706.svg ,即特征的每一个维度的值都会和预测的输出(logits)相乘。假设 ed892f3e-b957-eb11-8da9-e4434bdf6706.svg 是一个One-hot的向量,那么有以下好处(不仅仅是 cb892f3e-b957-eb11-8da9-e4434bdf6706.svg 和 ed892f3e-b957-eb11-8da9-e4434bdf6706.svg 简单的拼接 f1892f3e-b957-eb11-8da9-e4434bdf6706.svg , f2892f3e-b957-eb11-8da9-e4434bdf6706.svg 是拼接算子):

f5892f3e-b957-eb11-8da9-e4434bdf6706.svg

下面就是使用对抗的方法进行训练:

f6892f3e-b957-eb11-8da9-e4434bdf6706.svg

上面的优化和传统DaNN不一样的地方就是在于使用的不仅仅是特征 cb892f3e-b957-eb11-8da9-e4434bdf6706.svg ,而是使用的 fa892f3e-b957-eb11-8da9-e4434bdf6706.svg ,而 fa892f3e-b957-eb11-8da9-e4434bdf6706.svg 的定义是两种:当 fd892f3e-b957-eb11-8da9-e4434bdf6706.svg 时,即维度比较小时,直接使用 008a2f3e-b957-eb11-8da9-e4434bdf6706.svg ,否则使用带有Randomized Projection的近似来减小运算开销,具体为 028a2f3e-b957-eb11-8da9-e4434bdf6706.svg , 038a2f3e-b957-eb11-8da9-e4434bdf6706.svg 是每个元素满足均值为0,方差为1的随机映射矩阵, 068a2f3e-b957-eb11-8da9-e4434bdf6706.svg 表示点乘。

最后,为了防止一些比较难的样本在对抗训练中起到不好的作用(比如强行将分类效果不好的样本进行对齐会产生负作用),这里引入了一种基于熵降低难样本权重的方法(CDAN-E):

088a2f3e-b957-eb11-8da9-e4434bdf6706.svg

最后CDAN的示意图如下:

d513bbdf9d10050c6a86966165a2dcf7.png
08

MCD (CVPR2018)

ed9742ccc40cd14a70da8c26967ab373.gif
22bb5ac153905a41a7128fb154dbe756.png

传统的基于对抗的迁移,比如DaNN、ADDA等,都是通过一个显式的二分类器来当作Domain Classifier (Discriminator),这样做的结果可能会使得源域和目标域特征对齐的同时,源域的类别特征区分度会下降。MCD主要是通过两个分类器的差异(Discrepancy)来取代之前的Domain Classifier。首先两个分类器的差异(Discrepancy)为:

0c8a2f3e-b957-eb11-8da9-e4434bdf6706.svg

训练步骤如下:

第一步,训练源域的特征提取器和两个分类器,使其性能尽可能好;

第二步,固定住特征提取器,训练两个分类器,使得两个分类器的差异尽可能大,示意图如下(这里训练的时候仍然需要加上源域分类损失):

3d67ff993aafd9a8e87c97daf00544c9.png

第三步,固定住两个分类器,训练特征提取器,使得两个分类器的差异尽可能小,示意图如下:

64959959f726b6a11fb9f6da5f11eb3e.png

这个过程比较绕,仔细分析一下下面的训练步骤:

a9478e8a70ef445d2a59e3741149676b.png

图中需要注意的是:虚线椭圆代表的是源域样本,实线椭圆代表目标域;灰色代表类别A,棕色代表类别B。

Step 1:最左边图,两个分类器示意图,此时源域、目标域没有对齐(虚线和实线椭圆没有对齐),并且分类器可以较好地将源域两个类分开(虚线椭圆的灰色和棕色),但是由于目标域特征和源域特征存在差异(没有对齐),目标域上有误差(实线椭圆的灰色和棕色),阴影部分代表的是两个分类器分类不一致的地方,即Discrepancy;

Step 2:第二个图,最大化分类器差异,此时特征提取器固定住,椭圆的位置没有变化,分类器一方面要保证在源域上的性能,另一方面要保证Discrepancy最大,因此分类器会向源域类别的边界尽可能靠拢(向虚线椭圆的边界靠拢),此时目标域上样本的阴影变大(Discrepancy变大);

Step 3:第三个图,最小化分类器差异,此时分类器固定住,椭圆的位置在变化,为了最小化分类器差异,特征提取器提取的特征会使得目标域特征尽可能向分类正确的方向靠拢,因为第二步指出了“分类器会向源域类别的边界尽可能靠拢”,因此间接地使得源域和目标域样本对齐,且保证了类别之间的正确对齐;

Step 4:最后的效果图。

上述步骤只是理想的图示,实际情况比这个复杂得多,但是MCD的思想很好,且实际效果不错,相比较于DaNN这些方法,最后的效果示意图为:

f564dff63757d2dad799fd144678af51.png
09

GTA (CVPR2018)

ed9742ccc40cd14a70da8c26967ab373.gif
306f7ebf22304c72aef846a66491215a.png

GTA采用生成式方法对源域和目标域数据进行对齐,类似于DRCN和CoGAN的思想,对源域样本采用分类损失,同时对源域和目标域的特征进行生成式建模(利用Generator和Discriminator),这里的Discriminator采用了AC-GAN的思想,即预测数据是生成还是真实(Fake/Real)的同时,去预测这个图片的类别,因此称为Auxiliary Classifier GAN,框架图如下:

e06714b2ef792cb2d56cbfcc768fb1e0.png

上图包含特征提取器F、标签预测器C、生成器G、判别器D,其作用分别如下:

  • 特征提取器F:对源域和目标域提取特征;

  • 标签预测器C:预测源域样本数据的类别,即Classifier;

  • 生成器G:根据特征提取器提取的特征生成数据;

  • 判别器D:判别数据是真实数据还是生成的数据,同时对数据进行分类。

具体算法流程和每个部分更新的损失如下:

f1d72b42e59d0ea8f09b0509a457efc5.png

Discriminator包括两部分: 1c8a2f3e-b957-eb11-8da9-e4434bdf6706.svg ,前者对数据真实/伪造进行分类,促进生成的样本和真实样本分布一致,后者是Auxiliary Classifier,使得生成的数据也尽可能地被正确分类,更进一步促进了生成数据的真实性。

Discriminator的训练包括三个损失:1) 1d8a2f3e-b957-eb11-8da9-e4434bdf6706.svg 尽可能判别真实源域数据和通过源域特征 1e8a2f3e-b957-eb11-8da9-e4434bdf6706.svg 生成的数据;2) 1f8a2f3e-b957-eb11-8da9-e4434bdf6706.svg 尽可能对真实源域样本进行分类正确;3) 208a2f3e-b957-eb11-8da9-e4434bdf6706.svg 尽可能对目标域特征 228a2f3e-b957-eb11-8da9-e4434bdf6706.svg 生成的数据预测为合成的数据。

Generator的训练只在源域上和Discriminator对抗,使得生成的数据也尽可能通过 238a2f3e-b957-eb11-8da9-e4434bdf6706.svg 分类正确(根据Discriminator的第2个损失,只有分布和真实源域样本特别像的样本才会使得 238a2f3e-b957-eb11-8da9-e4434bdf6706.svg 分类损失小);其次,Generator会尽可能使得生成的数据 278a2f3e-b957-eb11-8da9-e4434bdf6706.svg 被Discriminator判别为真实数据。

Feature Extractor的训练包括三部分:1) 288a2f3e-b957-eb11-8da9-e4434bdf6706.svg 源域样本分类效果尽可能好;2) 1f8a2f3e-b957-eb11-8da9-e4434bdf6706.svg 特征提取器提取的源域样本特征,通过生成器G,通过Discriminator 238a2f3e-b957-eb11-8da9-e4434bdf6706.svg 可以被正确分类,目的还是希望提取的特征具有区分度,本质还是分类损失;3) 2e8a2f3e-b957-eb11-8da9-e4434bdf6706.svg 作为和Discriminator第三个Loss对抗的损失,目的是使得特征提取器在目标域提取的特征,通过生成器G,得到的图像被Discriminator预测为真实图像。由于Generator只在源域上训练,即:只有当目标域提取的特征和源域提取的特征很像时(F的输出特征得到了对齐),Generator才可以生成真实数据。

总的来说,GTA比较复杂,以及训练的时候Loss很多,很容易搞混。但是思想上就是通过生成数据的过程将源域和目标域提取的特征对齐。

参考文献:

  • Baochen Sun, Kate Saenko: Deep CORAL: Correlation Alignment for Deep Domain Adaptation. ECCV Workshops 2016.

  • Muhammad Ghifary, et al: Deep Reconstruction Classification Networks for Unsupervised Domain Adaptation. ECCV 2016.

  • Ming-Yu Liu, Oncel Tuzel: Coupled Generative Adversarial Networks. NeurIPS 2016.

  • Mingsheng Long, Han Zhu, Jianmin Wang, Michael I. Jordan: Unsupervised Domain Adaptation with Residual Transfer Networks. NeurIPS 2016.

  • Mingsheng Long, Han Zhu, Jianmin Wang, Michael I. Jordan: Deep Transfer Learning with Joint Adaptation Networks. ICML 2017.

  • Eric Tzeng, Judy Hoffman, Kate Saenko, Trevor Darrell: Adversarial Discriminative Domain Adaptation. CVPR 2017.

  • Mingsheng Long, Zhangjie Cao, Jianmin Wang, Michael I. Jordan: Conditional Adversarial Domain Adaptation. NeurIPS 2018.

  • Kuniaki Saito, Kohei Watanabe, Yoshitaka Ushiku, Tatsuya Harada: Maximum Classifier Discrepancy for Unsupervised Domain Adaptation. CVPR 2018.

  • Swami Sankaranarayanan, Yogesh Balaji, et al: Generate to Adapt: Aligning Domains Using Generative Adversarial Networks. CVPR 2018.

往期系列文章:

伪文艺程序员的“迁移学习”啃读(A)

伪文艺程序员的“迁移学习”啃读(B)

伪文艺程序员的“迁移学习”啃读(C)

伪文艺程序员的“迁移学习”啃读(D)

伪文艺程序员的“迁移学习”啃读(E)

伪文艺程序员的“迁移学习”啃读(F)

358152d487002b3073f5f7a4530e1131.png

 网络人工智能园地,力求打造网络领域第一的人工智能交流平台,促进华为iMaster NAIE理念在业界(尤其通信行业)形成影响力!

9aa2ea1bb2640689357805f896d06571.png

cc2c2fd65aed65fa04539c0c83875057.png

358152d487002b3073f5f7a4530e1131.png

1c3ae3ec54bd86bfbca8c83a0f01a818.png

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值