DSAN
深度子领域迁移网络
总结:将源域/目标域的各子领域(根据标签划分)分别进行对齐,而非全局对齐。
以往大多数深度域自适应方法都进行全局对齐(不考虑任何细粒度的信息,比如标签),会混淆源域/目标域的数据,丢失每个类别的细粒度信息。
本文提出了一种深度子领域自适应网络(DSAN),基于局部最大均值差异(LMMD),通过对齐源域/目标域上域特定层(domain-specific layer)激活的相关子领域分布来学习迁移网络。(不需要收敛缓慢的对抗性训练)
Related Work
- Domain Adaptation
- Maximum Mean Discrepancy(MMD)
TCA、DDC、DAN、[27]
扩展:条件MMD(JDA)、联合MMD(JAN)、加权MMD[32]、CMMD(JDA、[23]、[33];LMMD的一个特例) - Subdomain Adaptation
子领域自适应,也叫做语义对齐(semantic alignment)或匹配条件概率分布(matching conditional distribution,对齐全局特征实际上是匹配边缘概率分布)。
以下方法均采用对抗损失,复杂,有多个损失,且收敛缓慢。NIPS-2018 Conditional adversarial domain adaptation(CDAN):根据分类器预测中传递的判别信息对对抗自适应模型进行约束。
AAAI-2018 Multi-adversarial domain adaptation(MADA):捕获多模式结构,从而基于多域鉴别器实现不同数据分布的细粒度对齐。
NIPS-2018 Co-regularized alignment for unsupervised domain adaptation(Co-DA):构建多个不同的特征空间,在每个特征空间中单独对齐源分布和目标分布,同时鼓励对齐在未标记的目标样本上的类别预测方面彼此一致。
ICML-2018 Learning semantic representations for unsupervised domain adaptation(MSTN):通过对齐有标记的源质心和伪标记的目标质心来学习未标记目标样本的语义表示。
Method
- 源域
D
s
=
{
(
x
i
s
,
y
i
s
)
}
i
=
1
n
s
\mathcal{D}_s=\{(\mathbf{x}_i^s,\mathbf{y}_i^s)\}_{i=1}^{n_s}
Ds={(xis,yis)}i=1ns,其中
y
i
s
∈
R
C
\mathbf{y}_i^s\in\R^C
yis∈RC 是
x
i
s
\mathbf{x}_i^s
xis 的标签(one-hot编码),
C
C
C 是类别数量。
目标域 D t = { ( x j t } j = 1 n t \mathcal{D}_t=\{(\mathbf{x}_j^t\}_{j=1}^{n_t} Dt={(xjt}j=1nt。 - 源域和目标域采样自不同数据分布 p p p 和 q q q( p ≠ q p\ne q p=q)。
Subdomain Adaptation 子领域自适应
以往深度迁移学习方法:使用具有全局域自适应损失
d
^
(
p
,
q
)
\hat{d}(p,q)
d^(p,q) 的自适应层来学习域不变特征。
根据类别,将
D
s
\mathcal{D}_s
Ds 和
D
t
\mathcal{D}_t
Dt 划分为
C
C
C 个子领域
D
s
(
c
)
\mathcal{D}^{(c)}_s
Ds(c) 和
D
t
(
c
)
\mathcal{D}^{(c)}_t
Dt(c) ,其中类标签
c
∈
{
1
,
2
,
…
,
C
}
c\in\{1,2,\dots,C\}
c∈{1,2,…,C},
D
s
(
c
)
\mathcal{D}^{(c)}_s
Ds(c) 和
D
t
(
c
)
\mathcal{D}^{(c)}_t
Dt(c) 的分布为
p
(
c
)
p^{(c)}
p(c) 和
q
(
c
)
q^{(c)}
q(c)(类条件概率分布)。
结合分类损失和子领域自适应损失,子领域自适应方法的损失表述为
Local Maximum Mean Discrepancy(LMMD) 局部MMD
假设每个样本属于类别
c
c
c 的权重为
w
c
w^c
wc,给出(5)的无偏估计:
- w i s c w^{sc}_i wisc 和 w j t c w^{tc}_j wjtc 分别表示 x i s \mathbf{x}_i^s xis 和 x j t \mathbf{x}_j^t xjt 类别 c c c 的权重。
- ∑ i = 1 n s w i s c = ∑ j = 1 n t w j t c = 1 \sum_{i=1}^{n_s}w^{sc}_i=\sum_{j=1}^{n_t}w^{tc}_j=1 ∑i=1nswisc=∑j=1ntwjtc=1(均值)
- ∑ x i ∈ D w i c ϕ ( x i ) \sum_{\mathbf{x}_i\in\mathcal{D}}w^{c}_i \phi(\mathbf{x}_i) ∑xi∈Dwicϕ(xi) 是类别 c c c 的加权和。
样本
x
i
\mathbf{x}_i
xi 的权重
w
i
c
w^{c}_i
wic:
-
y
i
c
y_{ic}
yic 是预测标签向量(one-hot编码)
y
i
\mathbf{y}_i
yi 的第
c
c
c 个元素。
对于源域样本,使用真实标签 y i s \mathbf{y}_i^s yis;对于无标签的目标域样本,使用伪标签 y ^ i = f ( x i ) \hat{\mathbf{y}}_i=f(\mathbf{x}_i) y^i=f(xi)。
深度网络在第
l
l
l 层(
l
∈
L
=
{
1
,
2
,
…
,
∣
L
∣
}
l\in L=\{1,2,\dots,|L|\}
l∈L={1,2,…,∣L∣})生成激活
{
z
i
s
l
}
i
=
1
n
s
\{\mathbf{z}^{sl}_i\}^{n_s}_{i=1}
{zisl}i=1ns 和
{
z
j
t
l
}
j
=
1
n
t
\{\mathbf{z}^{tl}_j\}^{n_t}_{j=1}
{zjtl}j=1nt。另外,由于不能直接计算
ϕ
(
⋅
)
\phi(\cdot)
ϕ(⋅),将(6)写成
Deep Subdomain Adaptation Network
Experiment
- 目标识别:ImageCLEF-DA、Office-31、OfficeHome、VisDA-2017
- 数字分类:MNIST、USPS、SVHN