Transferrable Prototypical Networks for Unsupervised Domain Adaptation读书笔记

这篇文章是cvpr2019的一篇域适应的文章,我觉得很不错,和大家一起分享一下
在这篇文章里,通过对每个类到原型的距离的重新映射,学习一个嵌入空间并执行分类。具体地说,提出了可转移的原型网络(TPN)来进行自适应,使源域和目标域的每个类的原型在嵌入空间上都很接近,并且原型在源和目标数据上分别预测的分数分布也很简单。从技术上讲,TPN最初将每个目标示例与源域中最近的原型匹配,并将这个原型作为示例的“伪”标签。这样,每个类的原型就可以分别在源数据、目标数据和源-目标数据上进行计算。TPN的优化是通过三种类型数据上的样本间的方差和每对样本输出的分数分布的KL-散度的联合最小化来进行端到端训练的。
总体流程图如下所示
在这里插入图片描述TPN的实现流程:1.首先将输入样例通过神经网络嵌入到高维映射空间里,这种高维表示表征的是跨域的域不变特征。特别地,每个batch里包含源域与目标域样例,首先计算源域样例各个类别的原型,再通过计算各个目标域样例与各类别原型之间的距离,距离最近的就作为这个目标域样例的“伪”标签(“pseudo” label),
2.然后制定通用自适应,以使源数据上测量的原型、带有伪标签的目标数据和源-目标数据之间的距离最小化。这是为了减轻类层面上的域差异。在特定任务上的自适应,利用每个示例嵌入softmax输出到原型类型的距离大小为依据作为特定任务分类器。利用KL-散度,通过分类器对各个领域或它们的组合计算出的原型的不匹配度进行分类。整个TPN是端到端的训练,通过最小化标记源数据的分类损失和两个自适应项,并将学习从一个批次切换到另一个批次。
总的来说,通过最小化多粒度域离散度,并利用未标记的目标数据和标记的源数据构造分类器,将原型网络重新塑造成非监督域自适应的场景。

那么,什么是原型网络(Prototypical Networks)呢?
这个思想来自于小样本学习里面的一篇文章《Prototypical Networks for Few-shot Learning》,这里通过神经网络学会一个“好的”映射,将各个样本投影到同一空间中,对于每种类型的样本提取他们的中心点(mean)作为原型(prototype)。使用欧几里得距离作为距离度量,训练使得测试样本到自己类别原型的距离越近越好,到其他类别原型的距离越远越好。这里的原型与之相类似,即对每一类样本都以一个中心点作为原型,再通过使源域,目标域间相同类别的原型尽可能近,不同类别间的原型尽可能远来实现。
本文中,每一类的原型定义为:
在这里插入图片描述

f(x;θ)表征的是样例的高维嵌入表示。给定指定样例Xi,原型网络通过一个到原型距离输出通过softmax函数,直接生成其得分分布向量,其中第c元素表示隶属于第c类的概率。
在这里插入图片描述

d(.)表征的是所计算样例与源域各个原型之间的距离函数。原型网络的计算是通过最小化样例对应类别的log-likelihood实现。

本文最主要的创新就是只使用了原型网络,这种框架通过在每个类的原型上构造分类器,很自然地将特征和分类器的学习统一到一个网络中。这个设计反映了一个非常简单的归纳偏见,这是有益的领域适应机制。具体来说,为了使原型网络能够跨域传输,设计了两种自适应机制,通过减少多粒度来调整源域和目标域的分布域不变性。

总体实施方案:
1.总体域适应
首先通过原型网络给最近的目标域数据打标签,使所有目标域数据都有了标签,从而可以在源域,目标域,源-目标混合域上训练三种以原型为基础的分类网络。
在这里插入图片描述
上面三个量依次表示了三种情况下希尔伯特核空间里面的每一类的原型值。
具体做法是比较不同域间同类别原型值的差异,基本思想是,如果源域和目标域的数据分布相同,则在不同域中实现的同一类的原型是相同的。故其表达式仿照MMD距离公式,为
在这里插入图片描述
通过最小化这一项,在每个域中计算的每个类的原型将被强制在嵌入空间中非常接近,从而使跨域的域不变表示分布可以被拉近。
与MMD距离比较可以发现,MMD表示为跨域的完整原型类型之间的RKHS距离。我们的类级域差异(与MMD不同)被计算为来自不同域的每个类的原型之间的RKHS距离。换句话说,源域与目标域数据分布的细粒度对齐是在类级别上执行的,而不是简单地最小化跨域的完整原型之间的距离,所以更加细致。
2.根据任务的域适应
当源和目标分布很好地对齐时,每个源域/目标域样本应该由特定于任务的分类器正确分类,从而导致跨域分类器的决策一致。为了测量样本水平域的离散性,我们利用KL-散度来计算不同主元的得分分布之间的成对距离。将源样本和目标样本的样本级差异损失定义为:
在这里插入图片描述
通过最小化这个损失,拉近样例间的差异。
3.总损失函数
在这里插入图片描述
第一项是分类损失函数,第二项是类间差异损失,第三项是样例间差异损失。
后面就是实验部分,不再表述。

  • 5
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 6
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值