Deep Cocktail Network

Deep Cocktail Network

在这里插入图片描述

1.motivation

domain adaptation是由于获得大量的标注是一件耗时的工作,希望能通过利用已经有标注的source数据集来提升网络在没有标注的target数据集上的表现.本文的出发点是希望使用多个source数据集来进行domain adaptation。multi-source domain adaptation作者认为主要存在两个问题**1)domain shift:**包括源域和目标域,以及不同目标域之间;2) category-shift: 目标域的标签不一定完全一样。受目标域的概率分布可以由多个源域概率分布加权来表示,作者提出了deep cocktail network(DCTN)来解决上述问题。

在这里插入图片描述

2.method

在这里插入图片描述

(1) overview

网络结构主要包含三个部分 Feature Extractor(共享参数)将图像空间映射到特征空间;Domain Discrinator,一共有N个,N为源域的个数,即每一个源域都要和穆标宇训练一个分类器;Category Classifier,也是N个,每一个源域有一个category分类器。

以及最后的target classification operator,利用前面N个category 分类器来的分类loss进行weighted combination.

(2)Feature extractor

F为特征提取,将所有的图片映射到特征空间F(x), x表示输入图片

(3)Domain Discriminator

{ D s j } j = 1 N \{D_{s_j}\}_{j=1}^N {Dsj}j=1N表示N个discriminator, D s j D_{s_j} Dsj用来区分F(x)是来自源域还是目标域

S c f ( x t ; F , D s j ) = − l o g ( 1 − D s j ( F ( x t ) ) ) + α s j S_{cf}(x^t;F,D_{s_j}) =-log(1-D_{s_j}(F(x^t))) + \alpha_{s_j} Scf(xt;F,Dsj)=log(1Dsj(F(xt)))+αsj

其中 α s j \alpha_{s_j} αsj是source-specific concentration constant表示第j个discriminator在源域 X s j X_{sj} Xsj的平均值。 S c f ( x t ; F , D s j ) S_{cf}(x^t;F, D_{s_j}) Scf(xt;F,Dsj)是target-source perplexity scores 就是分类的loss,用作后续的加权平均, α \alpha α 作用是相当于进行了normalize.

(4)category classifier

(5)Target classification operator
C o n f i d e n c e ( c ∣ x t ) : = ∑ c ∈ C s j S c f ( x t ; F , D s j ) ∑ c ∈ C s k S c f ( x t ; F , D s k ) C s j ( c ∣ F ( x t ) ) Confidence(c|x^t):=\sum_{c\in C_{s_j}} \frac{S_{cf}(x^t;F,D_{s_j})}{\sum_{c\in C_{s_k}}S_{cf}(x^t;F, D_{s_k})}C_{s_j}(c|F(x^t)) Confidence(cxt):=cCsjcCskScf(xt;F,Dsk)Scf(xt;F,Dsj)Csj(cF(xt))

3. training

m i n F m a x D V ( F , D ; C ˉ ) = L a d v ( F , D ) + L c l s ( F , C ˉ ) min_{F}max_{D} V(F, D;\bar{C}) = L_{adv}(F, D) + L_{cls}(F,\bar{C}) minFmaxDV(F,D;Cˉ)=Ladv(F,D)+Lcls(F,Cˉ)

其中
L a d v ( F , D ) = 1 N ∑ j N E x ∼ X s j [ l o g D s j ( F ( x ) ) ] + E x t ∼ X t [ l o g ( 1 − D s j ( F ( x t ) ) ] L_{adv}(F, D) = \frac{1}{N}\sum_j^N E_{x\sim X_{s_j}}[log D_{s_j}(F(x))]+E_{x^t\sim X_t}[log(1-D_{s_j}(F(x^t))] Ladv(F,D)=N1jNExXsj[logDsj(F(x))]+ExtXt[log(1Dsj(F(xt))]
让网络学习分类误差最小,discriminator误差最大这样让source domain和target domain在feature space上混淆了,这样domain shift就变小了

**Online hard domain batch mining **

在这里插入图片描述

一共有N个源域,每个源域sample M个,一共N*M个样本。对N个discriminator中最大的进行更新

Target Discriminative Adaptation

作者认为荣国multi-way adversary,DCTN已经能学习到domain-invariant的特征,但是在target domain的分类能力不行。

为了逼近理想的target分类器,给target domain中的每一个样本打上pseudo labels(就是用之前的网络进行inference),然后在将target domain的数据和source domain的数据联合训练

在这里插入图片描述

因为没有针对target训练一个分类器,所以将target classification error反传到multi source的category classifier. 具体来说,对target的样本 ( x t , y ^ ) (x^t, \hat{y}) (xt,y^) 对源域中含有 y ^ \hat{y} y^ 类别的计算对应的分类loss,求和。
在这里插入图片描述

4 experiment

在这里插入图片描述

  • single best:对每对源域-目标域训练 选择最好的结果
  • source combine:将多个源域合并成单个domain

在这里插入图片描述

为了对比对category shift的效果设计了两种不同的category模式overlap(source类别之间是有交集的) disjoint(类别之间完全是没有交集的)

在这里插入图片描述
可以看到DTCN保证了domain的相似性以及类别之间的区分性

  • 0
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值