最优传输论文(六)Multi-Source Distilling Domain Adaptation

前言

本文提出了一种多源域适应方法,方法分为四步:
1.为每个源域分别训练特征提取器和分类器。
2.学习目标编码器F T i将目标特征映射到源域空间。类似于GAN对抗训练方式,1)训练Di来最大化从F T i编码后的目标特征和源特征的Wasserstein距离,目的是将源域目标域特征进行区分。
2)训练F T i来使D i出错,最小化F T i编码后目标特征和源特征的Wasserstein距离,将二者进行混淆
3.选取Wasserstein距离目标特征近的一半源域样本微调源分类器Ci。
4.对测试的目标样本预测值进行加权聚合。根据目标特征F Ti(xT),利用各自的源分类器得到每个源域的预测C’,根据源和目标特征之间的Wasserstein距离计算权值wi进行加权聚合。

Introduction

本文提出了一种新的多源蒸馏域自适应(MDDA)网络。
MDDA包括四个阶段:
(1)利用每个源的训练数据分别对各自的源分类器进行预训练;
(2)固定每个源域的特征提取器,通过最小化源与目标的经验Wasserstein距离,将目标分别逆向映射到每个源的特征空间中;
(3)选择离目标更近的源训练样本,对源分类器进行微调;
(4)通过相应的源分类器对每个编码后的目标特征进行分类,根据源域的不同权重对分类结果加权,构建目标预测器,域权值对应每个源域与目标之间的差异。
在这里插入图片描述
图1:探究不同源域和目标之间关系的MDDA模型。我们使用判别器D以对抗的方式测量每个源和目标之间的相似性ω。选取离目标较近的样本来提取源分类器C’ . 不同蒸馏源分类器的预测根据域相似度进行聚合,得到目标样本的最终预测。
在这里插入图片描述
在这里插入图片描述
图2:提出的多源蒸馏域自适应(MDDA)网络的框架。虚线矩形和梯形表示固定的网络参数。F、C、D分别为特征提取器(feature extractor)、分类器(classifier)、域鉴别器(domain discriminator)。为了简单起见,我们只考虑第i和第k个源域。本文提出的MDDA由四个阶段组成,从左到右分别为:源分类器预训练、对抗判别适应、源提取和聚合目标预测。

1.Source Classifier Pre-training

为每个标记源域Si预先训练了一个特征提取器Fi和分类器Ci,不同域之间具有未共享的权重。以n类分类任务为例,通过使交叉熵损失最小来优化Fi和Ci:
在这里插入图片描述
其中σ为softmax函数

2.Adversarial Discriminative Adaptation

在预训练阶段之后,我们学习单独的目标编码器F T i将目标特征映射到与源Si相同的空间中。1)对鉴别器Di进行对抗性训练,以最大化从F T i正确分类编码后的目标特征和从预训练的F i正确分类编码后的源特征的Wasserstein距离(Di试图将目标特征和源特征正确区分),而2)F T i试图最大化Di出错的概率,即最小化Wasserstein距离(将源特征目标特征进行混淆),类似于GAN 。
我们假设鉴别器{Di}都是1-Lipschitz,然后我们可以通过** **来优化Di。
在这里插入图片描述
F T i是通过最小化以下公式得到的:
在这里插入图片描述
为了执行Lipschitz约束,我们为每个鉴别器Di的参数添加梯度惩罚:
在这里插入图片描述
其中,x是一个特征集,它不仅包含源和目标特征,而且还包含源和目标特征对之间直线上的随机点。
Di可通过以下公式优化:
在这里插入图片描述

3.Source Distilling

这一步是第三步:根据Wasserstein距离来选取接近目标的源训练样本来微调源分类器。
对于第i个源域中的每个源样本x j i,我们计算每个源样本与目标域之间的Wasserstein距离:
在这里插入图片描述
τ j i反映x j i 到目标域的距离,τ j i值越小,越接近目标域,因此,我们选择每个源域中一半的距离较小的样本。使用这些选定的源数据,我们通过最小化以下公式来微调Ci:
在这里插入图片描述

4.Aggregated Target Prediction

测试阶段,目标是准确地对给定的目标图像xT进行分类。基于阶段2中学习到的目标编码器提取目标图像的特征F Ti (xT),使用蒸馏源分类器得到每个源域的特定预测C ’ i(F T i (xT))。
在这里插入图片描述
这里的关键问题是如何为不同源分类器的预测选择权重ωi。我们根据每个源和目标之间的差异设计了一种新的加权策略,以强调更多相关的源,抑制不相关的源,源域和目标域之间距离越远,这个源域权重越小。我们假设经过阶段2的训练后,每个源Si与目标T之间的估计Wasserstein距离LwdDi服从于一个标准高斯分布N(0,1)。因此,每个域的权重可以由下式计算
在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值