Transferrable Prototypical Networks for Unsupervised Domain Adaptation (CVPR 2019)

Transferrable Prototypical Networks for Unsupervised Domain Adaptation

论文: Yingwei Pan, Ting Yao, Yehao Li, Yu Wang, Chong-Wah Ngo, Tao Mei. Transferrable Prototypical Networks for Unsupervised Domain Adaptation. CVPR 2019.

##1. Motivation

文章提出了Transferrable representation in Prototypical Networks (TPN).去解决source domain下有标签,target domain下无标签的domain adaptation
分类问题。

文章的思路从分类器的定义出发:

img

Prototypical中分类器的定义: 对于一个样本 x x x的分类方法是通过计算特征空间下 x x x与每个类prototype的距离,进而分类:

所以可以通过更新prototype来灵活地实现跨模态实现分类任务。

首先通过source domain下的数据来预测target domain下的数据标签。在此之后,利用预测出来的pseudo label去学习另外两种基于prototypes的分类器:一个是利用target domain的数据得到的,另外一种通过source+target domain的数据得到。

之后TPN的训练通过以下方式实现:
1)正确地对原域进行分类
2) 以multi-granular(多种粒度,包含:class level (类级) 和 sample level(样本级))的方式减少domain discrepancy

在训练过程中的每个iteration,通过不断进行1)2)两步来用end-to-end的方式优化整个TPN。

类级别的 domain discrepancy 通过对齐每类的prototype来减少。样本级别domain discrepancy通过同步每个样本在不同domain下属于每个类的概率分布来减少。

下图中,右上角的图片指的是将每一类在source domain、target domian、source+ target domain下的prototype对齐。 这就是类别级的对齐。左下角的图片指的是对于每个样本,在不同domain下属于某个类别的概率应该相同。这种概率通过特征空间中样本与某一类prototype之间的距离计算,见之前介绍的Prototypical中分类器的定义

##2. 损失函数

总体损失函数:

其中 L s L_s Ls表示在有标签情况下source domain下的分类损失, L G L_G LG表示General-purpose Domain Adaptation,也就是类别级的损失函数, L T L_T LT表示Task-specific Domain Adaptation,也就是样本级别的损失函数。

###2.1 source domain下的分类损失:

img

###2.2 类别级的loss减少domain discrepancy
(General-purpose Domain Adaptation):

u c s ~ \widetilde{u_c^s} ucs 表示source domian下类别c在希尔伯特空间的映射值,利用与MMD相似的方法,展开然后用核函数表示该式即可。

###2.3 样本级别的loss以减少domain discrepancy
(Task-specific Domain Adaptation)

其中 p i c s p_ic^s pics表示样本 x i x_i xi根据source domain下的prototypes分类属于类别 c c c的概率。

For each sample x i x_i xi,根据$x_i与某一domain下各类prototypes的距离,来计算在这一domain下的分到各类类概率。在概率计算后,对于不同domain下,应该有同一个样本分到不同domain下各类的概率相当。这种近似程度可以用KLDivergence来衡量。对于不同类别累加KLDivergence。 由于KLdis不对称,对换参数位置再/2。

##3. Details

对于每个iteraion,1) 先利用特征提取网络$f去提取每个样本的特征并计算原域的prototype。根据之前提到过的prototypical分类器去给targetdomain下的样本分类,并贴上标签。这样就可以计算target和source+target domain下每类的prototype。

  1. 根据得到的各domain下的prototypes计算损失函数、反向传播更新特征提取网络 f f f

反复交替进行1)2)两步…

其中在贴pseudo标签时,要求每个样本计算出的属于某一类的概率的最大值至少要 > 0.6 \gt 0.6 >0.6,否则对于这个样本就不贴上标签。这样做是为了避免对噪声标签过拟合。

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值