Deep Cocktail Network
1.motivation
domain adaptation是由于获得大量的标注是一件耗时的工作,希望能通过利用已经有标注的source数据集来提升网络在没有标注的target数据集上的表现.本文的出发点是希望使用多个source数据集来进行domain adaptation。multi-source domain adaptation作者认为主要存在两个问题**1)domain shift:**包括源域和目标域,以及不同目标域之间;2) category-shift: 目标域的标签不一定完全一样。受目标域的概率分布可以由多个源域概率分布加权来表示,作者提出了deep cocktail network(DCTN)来解决上述问题。
2.method
(1) overview
网络结构主要包含三个部分 Feature Extractor(共享参数)将图像空间映射到特征空间;Domain Discrinator,一共有N个,N为源域的个数,即每一个源域都要和穆标宇训练一个分类器;Category Classifier,也是N个,每一个源域有一个category分类器。
以及最后的target classification operator,利用前面N个category 分类器来的分类loss进行weighted combination.
(2)Feature extractor
F为特征提取,将所有的图片映射到特征空间F(x), x表示输入图片
(3)Domain Discriminator
{ D s j } j = 1 N \{D_{s_j}\}_{j=1}^N {Dsj}j=1N表示N个discriminator, D s j D_{s_j} Dsj用来区分F(x)是来自源域还是目标域
S c f ( x t ; F , D s j ) = − l o g ( 1 − D s j ( F ( x t ) ) ) + α s j S_{cf}(x^t;F,D_{s_j}) =-log(1-D_{s_j}(F(x^t))) + \alpha_{s_j} Scf(xt;F,Dsj)=−log(1−Dsj(F(xt)))+αsj
其中 α s j \alpha_{s_j} αsj是source-specific concentration constant表示第j个discriminator在源域 X s j X_{sj} Xsj的平均值。 S c f ( x t ; F , D s j ) S_{cf}(x^t;F, D_{s_j}) Scf(xt;F,Dsj)是target-source perplexity scores 就是分类的loss,用作后续的加权平均, α \alpha α 作用是相当于进行了normalize.
(4)category classifier
(5)Target classification operator
C
o
n
f
i
d
e
n
c
e
(
c
∣
x
t
)
:
=
∑
c
∈
C
s
j
S
c
f
(
x
t
;
F
,
D
s
j
)
∑
c
∈
C
s
k
S
c
f
(
x
t
;
F
,
D
s
k
)
C
s
j
(
c
∣
F
(
x
t
)
)
Confidence(c|x^t):=\sum_{c\in C_{s_j}} \frac{S_{cf}(x^t;F,D_{s_j})}{\sum_{c\in C_{s_k}}S_{cf}(x^t;F, D_{s_k})}C_{s_j}(c|F(x^t))
Confidence(c∣xt):=c∈Csj∑∑c∈CskScf(xt;F,Dsk)Scf(xt;F,Dsj)Csj(c∣F(xt))
3. training
m i n F m a x D V ( F , D ; C ˉ ) = L a d v ( F , D ) + L c l s ( F , C ˉ ) min_{F}max_{D} V(F, D;\bar{C}) = L_{adv}(F, D) + L_{cls}(F,\bar{C}) minFmaxDV(F,D;Cˉ)=Ladv(F,D)+Lcls(F,Cˉ)
其中
L
a
d
v
(
F
,
D
)
=
1
N
∑
j
N
E
x
∼
X
s
j
[
l
o
g
D
s
j
(
F
(
x
)
)
]
+
E
x
t
∼
X
t
[
l
o
g
(
1
−
D
s
j
(
F
(
x
t
)
)
]
L_{adv}(F, D) = \frac{1}{N}\sum_j^N E_{x\sim X_{s_j}}[log D_{s_j}(F(x))]+E_{x^t\sim X_t}[log(1-D_{s_j}(F(x^t))]
Ladv(F,D)=N1j∑NEx∼Xsj[logDsj(F(x))]+Ext∼Xt[log(1−Dsj(F(xt))]
让网络学习分类误差最小,discriminator误差最大这样让source domain和target domain在feature space上混淆了,这样domain shift就变小了
**Online hard domain batch mining **
一共有N个源域,每个源域sample M个,一共N*M个样本。对N个discriminator中最大的进行更新
Target Discriminative Adaptation
作者认为荣国multi-way adversary,DCTN已经能学习到domain-invariant的特征,但是在target domain的分类能力不行。
为了逼近理想的target分类器,给target domain中的每一个样本打上pseudo labels(就是用之前的网络进行inference),然后在将target domain的数据和source domain的数据联合训练
因为没有针对target训练一个分类器,所以将target classification error反传到multi source的category classifier. 具体来说,对target的样本
(
x
t
,
y
^
)
(x^t, \hat{y})
(xt,y^) 对源域中含有
y
^
\hat{y}
y^ 类别的计算对应的分类loss,求和。
4 experiment
- single best:对每对源域-目标域训练 选择最好的结果
- source combine:将多个源域合并成单个domain
为了对比对category shift的效果设计了两种不同的category模式overlap(source类别之间是有交集的) disjoint(类别之间完全是没有交集的)
可以看到DTCN保证了domain的相似性以及类别之间的区分性