Weighted Maximum Mean Discrepancy for Unsupervised Domain Adaptation
摘要
在区域自适应(domain adaptation)中,最大平均差(maximum mean discrepancy,MMD)作为源域和目标域分布的一种差异度量方法被广泛采用,然而,现有的基于MMD的域自适应方法通常忽略了类先验分布的变化,比如跨域的类权重偏差(class weight bias across domains),对于领域适应来说是普遍存在的,作者认为这可能是由于样本选择标准和应用程序场景的变化造成的。为了解决这个问题,作者提出了加权MMD(weighted MMD)方法。
作者利用源域和目标域上的类先验概率,将类特定的辅助权重引入到原始MMD中,但是目标域中的类标签不可用,该文章的加权MMD模型是通过在源域中为每个类引入一个辅助权重来定义的,并通过分配伪标签、估计辅助权重和更新模型参数之间的交替,提出了一种分类CEM算法,实验表明WMMD优于传统的MMD。文章地址。
类先验分布(Class prior distributions)
比如在MNIST、USPS和SVHN数据集中,各种类的分布是不一致的。
WMMD方法
类权重偏置下MMD和WMMD正则化项(regularizer)的最小化理论结果和过程分别如(a)和(b)所示。最小化mmd会保留源域中的类权值(class weights),从而错误地估计目标样本,如(a)黄色虚线所示。与之相反,所提出的加权MMD通过对源数据进行第一次加权来消除类偏差的影响,如(b)黄色虚线所示。
域适应(Domain adaptation)中的MMD方法
最大平均偏差(MMD)是一种在两个数据集上比较分布的有效的非参数度量方法,给定两个分布s和t,通过φ(·)将数据映射到再生核希尔伯特空间(reproducing kernel Hilbert space -RKHS),s和t的MMD值可以定义为:
E
x
s
−
s
[
∗
]
E_{x^s-s}[*]
Exs−s[∗] 为s的期望,
∣
∣
φ
∣
∣
H
≤
1
||φ||_{H ≤ 1}
∣∣φ∣∣H≤1 为RKHS的单位球中所定义一组函数。
特征核与feature map φ 有关,定义于m PSD核 { k u } \{k_u\} {ku}的凸组合
文章中使用的是高斯核函数:
不论是MMD还是WMMD都是采用的多核函数的方法,具体是采用多个高斯核
{
k
u
}
u
=
1
m
\{k_u\}^m_{u=1}
{ku}u=1m,
γ
u
=
[
2
−
8
γ
,
2
8
γ
]
\gamma_u=[2^{-8} \gamma,2^8 \gamma ]
γu=[2−8γ,28γ],
γ
\gamma
γ取值步长为
2
1
/
2
2^{1/2}
21/2
WMMD方法
类条件分布(class conditional distributions):
w
c
u
w^u_c
wcu为source或target domain的类先验概率,C为分类数量,p为概率密度函数。
文章提出用
p
s
,
a
x
s
p_{s,a}{x^s}
ps,axs来比较source 和 target domain 之间的差异。
其中,
a
c
=
w
c
t
/
w
c
s
a_c=w^t_c/w^s_c
ac=wct/wcs。
WMMD通过下面的方程计算,
h
l
,
w
(
z
i
)
h_{l,w}(z_i)
hl,w(zi)为一个四元组,表示为:
WMMD的目标函数定义为:
前面两项是source domain和target domain 的 soft-max loss 项,最后一项为
l
1
−
l
2
l_1 - l_2
l1−l2层的WMMD正则项。
模型
DAN(domain adaptation network)框架
该文章是在文章《Learning Transferable Features with Deep Adaptation Networks》工作的基础上改进的,其模型也类似于后者的DAN模型,所给示例基于AlexNet,因为是迁移学习,原文实验中采用的模型包含从ImageNet 2012的预训练AlexNet模型迁移过来的,其中conv1-conv3直接使用保持不变(frozen),即lr=0,conv4-conv5略微调整(fine-tune)以适应域差异来迁移特征(domain-bias),lr=0.1,fc6-fc8采用具体MMD/WMMD方法处理来完成对task-specific转化。这样做的原因在于,随着网络深度的加深,特征在网络中由一般特征向特定特征过渡,随着区域差异的增大,深度特征在更高层次上的可移植性显著下降,fc层更是向源任务目标变得特化,对于AlexNet学习类似的数据来说,前三层学习到的可能更多是通用特征(general feature),后面学习到的更多是特定的特征(specific feature)(《how transferable are features in deep neural networks?》)。
训练过程
- 把source和target domain data 无差别通过conv1-conv5进行训练
- 在fc层,souce 和 target domain data 分开训练,通过MMD/WMMD方法计算两个domain之间的差异(距离),通过损失函数最小化更新参数
- 如果是WMMD,还包括CEM算法步骤
target domain缺少label信息如何计算loss
对于domain adaptation问题来说,其中一个难点就是target domain的label信息没有或很少,因此像传统方法一样以target data去对以训练好地模型进行调参是不可行的或导致过拟合。所以提出了DAN/WDAN模型来解决这个问题。
WDAN做法是采用CEM算法计算伪标签来计算loss,下面会给出CEM算法详细步骤
CEM 算法
因为问题的特殊性,target domain 的 label 在训练的过程中是不知道的,作者提出了一个网络weighted domain
adaptation network (WDAN) model,这个model是一个加入了WMMD项的CNN半监督逻辑回归模型的重要拓展。模型的训练算法拓展于EM算法,取名CEM算法。
算法步骤
给出模型参数W,对于每个 x j t x^t_j xjt,先根据softmax分类器输出估计类的后验概率,伪标签 y j y_j yj 由 x j t x^t_j xjt最大后验概率赋值,辅助权值α由伪标签估计得到,得到 y j j = 1 N {y_j}^N_{j=1} yjj=1N和α,传统的反向传播算法就可以应用于更新权值W。
- E-step:估计target domain的后验概率。把target domain 数据放到模型里面里面,拿到其输出。
- C-step:赋值伪标签并估计辅助权重α,取E-step里面最大的下标,也就是估计为哪一类,然后根据下面公式得到α。
- M-step:更新权重W
实验
模型选择和正则项位置
作者选择了LeNet,AlexNet,GoogleNet和VGG-16作为基础模型,具体来说,WMMD正则项加在AlexNet最后三个全连接层,GoggleNet最后一个inception和全连接层,LeNet最后一个卷积层。github源代码。
实验结果
值得一提的是,作者为了验证其观点(现有的基于MMD的域自适应方法通常忽略了类先验分布的变化,导致网络性能下降),做了一个小实验,挑选只包含两个类的source domain和target domian,通过改变两个类的概率分布分次实验,发现WDAN(Weighted Domain Adapation Network)的方法是最鲁棒的,也进一步验证了作者的猜想。
同时,作者还对WDAN和DAN学到的特征进行可视化,可以看到WDAN的分类效果更好。