Transferrable Prototypical Networks for Unsupervised Domain Adaptation
论文: Yingwei Pan, Ting Yao, Yehao Li, Yu Wang, Chong-Wah Ngo, Tao Mei. Transferrable Prototypical Networks for Unsupervised Domain Adaptation. CVPR 2019.
##1. Motivation
文章提出了Transferrable representation in Prototypical Networks (TPN).去解决source domain下有标签,target domain下无标签的domain adaptation
分类问题。
文章的思路从分类器的定义出发:
Prototypical中分类器的定义: 对于一个样本 x x x的分类方法是通过计算特征空间下 x x x与每个类prototype的距离,进而分类:
所以可以通过更新prototype来灵活地实现跨模态实现分类任务。
首先通过source domain下的数据来预测target domain下的数据标签。在此之后,利用预测出来的pseudo label去学习另外两种基于prototypes的分类器:一个是利用target domain的数据得到的,另外一种通过source+target domain的数据得到。
之后TPN的训练通过以下方式实现:
1)正确地对原域进行分类
2) 以multi-granular(多种粒度,包含:class level (类级) 和 sample level(样本级))的方式减少domain discrepancy
在训练过程中的每个iteration,通过不断进行1)2)两步来用end-to-end的方式优化整个TPN。
类级别的 domain discrepancy 通过对齐每类的prototype来减少。样本级别domain discrepancy通过同步每个样本在不同domain下属于每个类的概率分布来减少。
下图中,右上角的图片指的是将每一类在source domain、target domian、source+ target domain下的prototype对齐。 这就是类别级的对齐。左下角的图片指的是对于每个样本,在不同domain下属于某个类别的概率应该相同。这种概率通过特征空间中样本与某一类prototype之间的距离计算,见之前介绍的Prototypical中分类器的定义。
##2. 损失函数
总体损失函数:
其中 L s L_s Ls表示在有标签情况下source domain下的分类损失, L G L_G LG表示General-purpose Domain Adaptation,也就是类别级的损失函数, L T L_T LT表示Task-specific Domain Adaptation,也就是样本级别的损失函数。
###2.1 source domain下的分类损失:
###2.2 类别级的loss减少domain discrepancy
(General-purpose Domain Adaptation):
u c s ~ \widetilde{u_c^s} ucs 表示source domian下类别c在希尔伯特空间的映射值,利用与MMD相似的方法,展开然后用核函数表示该式即可。
###2.3 样本级别的loss以减少domain discrepancy
(Task-specific Domain Adaptation)
其中 p i c s p_ic^s pics表示样本 x i x_i xi根据source domain下的prototypes分类属于类别 c c c的概率。
For each sample x i x_i xi,根据$x_i与某一domain下各类prototypes的距离,来计算在这一domain下的分到各类类概率。在概率计算后,对于不同domain下,应该有同一个样本分到不同domain下各类的概率相当。这种近似程度可以用KLDivergence来衡量。对于不同类别累加KLDivergence。 由于KLdis不对称,对换参数位置再/2。
##3. Details
对于每个iteraion,1) 先利用特征提取网络$f去提取每个样本的特征并计算原域的prototype。根据之前提到过的prototypical分类器去给targetdomain下的样本分类,并贴上标签。这样就可以计算target和source+target domain下每类的prototype。
- 根据得到的各domain下的prototypes计算损失函数、反向传播更新特征提取网络 f f f。
反复交替进行1)2)两步…
其中在贴pseudo标签时,要求每个样本计算出的属于某一类的概率的最大值至少要 > 0.6 \gt 0.6 >0.6,否则对于这个样本就不贴上标签。这样做是为了避免对噪声标签过拟合。