introduce
- 本文研究的范围仅限于UDA(unsupervised domain adaptation)
- 作者认为使用MMD(maximum mean discrepancy)来衡量source domain和target domain之间的差异不够准确,这是因为没有考虑class prior distributions(类的先验分布,就是某个类在整个domain中所占的比重),为了解决这个问题,作者提出了一个叫做weighted MMD(WMMD)的模型。(Despite the great success
achieved, existing ones(MMD-based methods) generally ignore the changes of
class prior distributions, dubbed by class weight bias.) - 对基于MMD的域适应方法来说,对class weight bias(各个类中的样本数所占的比重应该就是class weight)的忽略可能导致性能的下降(For MMD-based methods, the ignorance of class weight
bias can deteriorate the domain adaptation performance) - 如下图:
- MMD的限制在于当source domain和taregt domian中的class weight不同(或者如图中所示,更严重地target domain缺少source domain中的类)时,使用MMD会导致错误的分类(MMD会使得target domain的class weight强行与source domain一致)。
- 然而问题是,target domain是没有label的,所以target domain的class weight是未知的
- 因此作者首先引入了class-specific auxiliary
weights(类特定辅助权重?)来对source domain进行reweight,使得source domain的class weight和target domain的完全一致。 - 通过最小化weighted MMD(WMMD)的目标函数来共同优化auxiliary
weights的估计量和模型参数学习。 - 作者使用一个叫做 classification EM (CEM)的方案来估计他。
- 在E步骤和C步骤中,计算类后验概率(target domain的class weight的后验概率),将伪标签(pesudo label)分配给target domain的样本,并估计auxiliary
weights。 - 在M步骤当中,通过最小化目标函数的损失来更新参数(普通的机器学习训练过程)。
maximum mean discrepancy
- (MMD基础理论部分,数学用语很多,不想翻译了,我就直接贴截图了)
Weighted Maximum Mean Discrepancy
- ps(xs) 、 pt(xt) :source domain和target domain的概率分布密度
- 以上二者都可以用类的条件分布的混合来表示:
- 其中 wsc=ps(ys=c) 和 wsc=ps(yt=c) 就是前文所提到的class prior probility(class weight)。
- MMD比较的是 ps(xs) 和 pt(xt) ,也就是概率密度,但作者认为比较source domain和target domain的条件概率密度 ps(xs|ys=c) 和 pt(xt|yt=c) 更为有效(这个是判别式模型(discrinimative model)学习的目标:类的后验概率)
- 作者建议利用reference source distribution ps,α(xs) 来计算source domain和target domain之间的差异(discrepancy)
要求 ps,α(xs) 和target domain有一样的class weight( wsc=ps(ys=c) )并且保留source domain的条件概率密度( ps(xs|ys=c) ),所以:
weight MMD的基本形式:(利用 ps,α(xs) 和 pt(xt) 来计算):
- weight MMD的线性复杂度近似(为了速度和SGD,详细的理论见上面MMD后半部分)
- weight MMD的线性复杂度近似(为了速度和SGD,详细的理论见上面MMD后半部分)
Weighted Domain Adaptation Network
- 作者认为WMMD正则化层需要加载CNN的高层,因为dataset bias会在高层增加:
- WDAN(Weighted Domain Adaptation Network)模型:
- 优化WDAN的过程:
- E-step:估计target domain的类
{xtj}Nj=1
的后验概率(class posterior probility)
- C-step:基于E-step中计算出的最大的class posterior probility,将伪标签(pseudo-label)
{yˆNj=1}
赋给每个
xtj
,并且估计辅助权重(auxiliary weights)
α
- M-step:在给定的
α
和
{yˆNj=1}
下更新模型参数
W
<script type="math/tex" id="MathJax-Element-21">W</script>:
- E-step:估计target domain的类
{xtj}Nj=1
的后验概率(class posterior probility)