联邦学习——基于知识蒸馏的多源域适应

《KD3A: Unsupervised Multi-Source Decentralized Domain Adaptation via Knowledge Distillation》是一篇被ICML 2021接收的论文,论文作者来自浙江大学CAD&CG国家重点实验室。这篇文章提出一个对来自多个源域的模型进行知识蒸馏来构建可迁移的共识知识,并可应用于联邦学习的无监督多源域适应方法。

简介

传统的无监督多源域适应UMDA(Unsupervised Multi-source Domain Adaptation)假设所有源域数据都可以直接访问。然而,隐私保护政策要求所有数据和计算都必须在本地进行,这对域适应方法提出了三个挑战:
首先,最小化域间距离需获取源域和目标域的数据并进行成对计算,而在隐私保护要求下,源域数据本地存储、不可访问
其次**,通信成本和隐私安全限制了现有域适应方法的应用。
最后,由于无法鉴别源域数据质量,更易出现
不相关或恶意源域**,从而导致负迁移。
为解决上述问题,该文章提出一种满足隐私保护要求的去中心化无监督域适应范式,称为基于知识蒸馏的去中心化域适应KD3A(Knowledge Distillation based Decentralized Domain Adaptation),通过对来自多个源域的模型进行知识蒸馏来构建可迁移的共识知识。大量实验表明,KD3A显著优于其他域适应方法。此外,与其他去中心化的域适应方法相比,KD3A 对负迁移具有鲁棒性,并可将通信成本降低100倍。

UMDA的相关研究即不足

UMDA通过减小源域 D S D_S DS和目标域 D T D_T DT之间的H-divergence来建立可迁移特征。maximum mean discrepancy (MMD)和对抗训练是实现H-divergence的优化策略的两种主流范式。此外,知识蒸馏也被用于模型上的知识迁移。
基于MMD的各种方法都需要对来自源域和目标域的数据进行成对计算,这在去中心化约束下是不允许的。
对抗训练要求完成每个batch的训练后,每个源域都要与目标域交换和更新模型参数,这会产生巨大的通信代价。
由于不相关和恶意源域,应用于域适应中的传统的知识蒸馏策略可能无法获得正确的知识。

KD3A技术框架

该文章提出了一种解决上述问题的方法——Knowledge Distillation based Decentralized Domain Adaptation (KD3A),旨在通过对来自不同源领域的模型进行知识蒸馏来实现去中心化的域适应
KD3A通过串联使用三个组件组成。
首先提出一种多源知识提取方法——Knowledge Vote以获得高质量的域间共识知识
第二,根据不同域的共识知识的质量,设计了一种动态加权策略Consensus Focus识别恶意和无关的源域
最后给出了一种对H-divergence的去中心化优化策略——BatchNorm MMD。并从理论角度分析了KD3A的去中心化泛化边界
框架中提到的共识知识是什么,共识知识的质量怎么定义,具体权重要怎么给?别急,继续看下面的技术细节。

KD3A技术细节

D S D_S DS D T D_T DT分别代表源域和目标域,在UMDA中,设置k个源域{ D S k D_S^k DSk} k = 1 k _{k=1}^k k=1k,每个源域包括 N K N_K NK个已标注样本: D S k D_S^k DSk:={ x i k x_i^k xik, y i k y_i^k yik} i = 1 N k _{i=1}^{N_k} i=1Nk。目标域 D T D_T DT具有 N T N_T NT个无标注样本: D T D_T DT:={ x i T x_i^T xiT} i = 1 N T _{i=1}^{N_T} i=1NT
UMDA的目标是学习一个模型h,该模型能最小化 D T D_T DT的任务风险 ε D T ε_{D_T} εDT ε D T ε_{D_T} εDT定义如下:
在这里插入图片描述

一、利用共识知识扩展源域

考虑一个C-way分类任务,并假设目标域与源域的任务相同。
假设k个源域上的k个完全训练模型为{ h S k h_S^k hSk} k = 1 K _{k=1}^K k=1K q S k ( X ) q_S^k(X) qSk(X)代表每个类的置信度,并且以具有最高置信度的类作为标签,
h S k ( X ) h_S^k(X) hSk(X) = a r g c m a x [ q S k ( X ) ] c arg_cmax[q_S^k(X)]_c argcmax[qSk(X)]c
如图a所示,UMDA中的知识蒸馏可以分为两步。
首先,对于每个目标域的数据 X i T X_i^T XiT,获取各源域中模型的推断(inferences)。
然后,利用集成方法获得各源模型的共识知识 P i P_i Pi,即 P i P_i Pi = 1/K Σ k = 1 K Σ_{k=1}^K Σk=1K q S k ( X i T ) q_S^k(X_i^T) qSk(XiT)
在这里插入图片描述
为了将共识知识应用于源域,我们根据每个目标域数据 X i T X_i^T XiT和共识知识 P i P_i Pi定义了一个扩展源域 D S K + 1 D_S^{K+1} DSK+1 D S K + 1 D_S^{K+1} DSK+1 = { ( X i T , P i ) (X_i^T, P_i) (XiT,Pi)} i = 1 N T _{i=1}^{N_T} i=1NT
D S K + 1 D_S^{K+1} DSK+1的任务风险定义为:
在这里插入图片描述
在这一新源域上,可以以知识蒸馏的损失函数来定义源模型 h S k + 1 h_S^{k+1} hSk+1的训练,定义如(3)。
在这里插入图片描述

二、Knowledge Vote : Producing Good Consensus

若如上一小节所述,只是简单地用集成方法来获取各源模型的共识知识 P i P_i Pi = 1/K Σ k = 1 K Σ_{k=1}^K Σk=1K q S k ( X i T ) q_S^k(X_i^T) qSk(XiT),可能结果并不如人意,因为源域中可能存在一些无关的或者恶意源域。因此,文章提出Knowledge Vote以获得高质量的共识知识。
Knowledge Vote的主要思想是,如果某个共识知识被更多高置信度的源域支持(例如> 0.9),那么它就更有可能是真正的标签。如图b所示,Knowledge Vote包括三个步骤:
1、confidence gate:过滤”不自信的模型”(如 q s 4 q_s^4 qs4,对于一个二分类任务, q s 4 q_s^4 qs4给出stairs的概率是0.4,floors的概率是0.6,这是一个很模糊的结果,由于他不敢给floors或stairs更高的置信度,所以说他是不自信的)
2、consensus class vote:对剩下的模型,进一步过滤那些预测结果与大众相悖的模型(如 q s 3 q_s^3 qs3,虽然很自信的给了floors很高的置信度,但和大众相悖,也就是同共识知识相悖。)
3、mean ensemble:经过前两轮过滤,我们将剩下的模型进行平均聚合,改用这些模型的个数 n P i n_{P_i} nPi作为知识蒸馏损失函数(式3)的权重,对于那些在confidence gate中过滤掉的模型也是进行平均聚合,但只给一个很小的权重 n P n_P nP = 0.001。参数调整后的知识蒸馏损失函数如5所示。
在这里插入图片描述
在这里插入图片描述
与其他集成策略相比, Knowledge Vote使模型学习到高质量的共识知识,因为我们为那些具有高置信度和许多支持域的类分配了更高的权重。

三、Consensus Focus : Against Negative Transfer

Consensus Focus的主要思想是为那些提供高质量共识知识的域分配高权重,并惩罚那些提供不良共识知识的域。我们首先导出共识知识质量的定义,然后计算每个源域对共识知识质量的贡献。

1、共识知识的质量

原理:如果一个共识类被更多源域赋予更高的置信度,那就更可能是一个true label,意味着共识知识的质量越高。
因此,设源域的集合为S={ D S k D_S^k DSk} k = 1 K _{k=1}^K k=1K,S’为S的子集,对于每个目标域数据 X i T X_i^T XiT和对应的S’中的共识知识( P i ( S ′ ) , n P i ( S ′ ) P_i(S'), n_{P_i}(S') Pi(S),nPi(S)),定义共识知识质量为 n P i ( S ′ ) ⋅ m a x P i ( S ′ ) n_{P_i}(S')·maxP_i(S') nPi(S)maxPi(S)(更多源域赋予更高的置信度),共识知识质量的总和Q定义为
在这里插入图片描述

2、源域的贡献度

根据共识知识的质量定义(7)可以推导出consensus focus (CF) value,来量化每个源域的贡献度,贡献度越高,CF值越高。公式如(8)所示
在这里插入图片描述

3、每个源域的权重分配

由于我们在Knowledge Vote环节给源域集合K引入了一个扩展源域 D S K + 1 D_S^{K+1} DSK+1,需要通过两个步骤来为每个源域分配权重
首先我们根据数据量把权重 a K + 1 a_{K+1} aK+1 = N T / ( Σ k = 1 K N K + N T ) N_T/(Σ_{k=1}^KN_K + N_T) NT/(Σk=1KNK+NT)分配给 D S K + 1 D_S^{K+1} DSK+1
然后使用CF值重新加权每个原先的源域,如下所示
在这里插入图片描述
Consensus Focus有两个优点。首先,计算 a C F a^CF aCF不需要访问原始数据。第二,通过Consensus Focus获得的 a C F a^CF aCF是基于共识知识的质量,而共识知识的质量的计算涉及数据和标签信息,可以识别恶意域。

四、BatchNorm MMD : Decentralized Optimization Strategy of H−divergence

为了获得更好的UMDA性能,我们需要最小化源域和目标域之间的H-divergence。文章利用深度学习中的Batch Norm层所记录的均值和方差参数,提出了BatchNorm MMD,实现在不访问数据的情况下优化H - divergence。
BatchNorm MMD分两步执行H−divergence的去中心化的优化策略。
首先,对于特征π,假设模型上具有L个Batch Norm层,我们从不同源域的模型上获得{ ( E ( π l k ) , V a r ( π l k ) ) (E(π_l^k), Var(π_l^k)) (E(πlk),Var(πlk))} l = 1 L _{l=1}^L l=1L
然后对于每个batch∈ D T D_T DT,我们训练模型 h T h_T hT通过损失函数(12)来优化域适应目标(10)。
在这里插入图片描述
在这里插入图片描述
其中( π 1 T π^T_1 π1T,…, π L T π^T_L πLT)是目标模型 h T h_T hT从对应于输入 X T X_T XT的Batch Norm层得到的特征。在训练过程中,我们使用每个minibatch的均值µ来估计期望E。另外,优化(12)需要遍历所有Batchnorm层,非常耗时。因此,文章也在附录中给出了一个计算高效的优化方式,该方法如下所示:
在这里插入图片描述

五、Generalization Bound For KD3A

结合原始边界和知识蒸馏边界进一步推导出KD3A的泛化边界。
记H为模型空间,{ ε D S k ( h ) ε_{D_S^k}(h) εDSk(h)} k = 1 K _{k=1}^K k=1K ε D T ( h ) ε_{D_T}(h) εDT(h)分别为源域{ D S K D_S^K DSK} k = 1 K _{k=1}^K k=1K和目标域 D T D_T DT上的任务风险误差,对于由多个源域模型集成而来的全局模型 h T h_T hT∈H,满足 h T h_T hT = Σ k = 1 K + 1 a k h S k Σ_{k=1}^{K+1}a_kh_S^k Σk=1K+1akhSk,有如下泛化误差边界
在这里插入图片描述
KD3A界(13)的泛化性能取决于共识知识的质量,如proposition 2所示
在这里插入图片描述
proposition 2提出了两个更严格的约束条件:(1)对于那些H−divergence小、最优任务风险 λ S k λ^k_S λSk较低的源域,模型应利用它们的优势提供更好的共识知识,即让任务风险 ε D s K + 1 ε_{D^{K+1}_s} εDsK+1 ε D T ε_{D_T} εDT足够接近。(2)对于H−divergence和λ较高的不相关和恶意源域,模型应该过滤掉它们的知识,即让任务风险 ε D s K + 1 ε_{D^{K+1}_s} εDsK+1远离那些不良源域。
KD3A通过Knowledge Vote和Consensus Focus启发式地实现了上述两个条件:(1)对于好的源域,KD3A提供了更好的共识知识和Knowledge Vote,使 ε D s K + 1 ε_{D^{K+1}_s} εDsK+1 ε D T ε_{D_T} εDT更加接近。 (2)对于不良源域,KD3A用Consensus Focus过滤掉它们的知识,使得 ε D s K + 1 ε_{D^{K+1}_s} εDsK+1远离不良源域。文章中还进行了充分的实验,证明KD3A相比与其他UMDA方法,实现了更严格的界和更好的性能。

六、The Algorithm of KD3A

首先,我们获得一个额外的源域 D s K + 1 D^{K+1}_s DsK+1,并通过Knowledge Vote训练源模型 h s K + 1 h^{K+1}_s hsK+1
然后,我们通过Consensus Focus将k+1个源模型进行聚合,得到目标模型,即 h T h_T hT := Σ k = 1 K + 1 a k h s k Σ_{k=1}^{K+1}a_kh_s^k Σk=1K+1akhsk
最后,我们通过Batchnorm MMD最小化目标模型的H−divergence。
KD3A的去中心化训练过程如算法1所示。Confidence gate是KD3A中唯一的超参数,应慎重对待。如果Confidence gate过大,几乎所有目标域的数据都会被消除,Knowledge Vote损失函数将不起作用。如果Confidence gate太小,那么共识知识的质量就会降低。因此,我们将参数从小(eg: 0.8)到达(eg: 0.95)逐渐增大。
在这里插入图片描述

实验

1、数据集

文章在四个基准数据集上进行了实验:(1) Amazon Review (Ben-David et al., 2006),这是一个情感分析数据集,包含四个源域。(2)Digit-5 (Zhao et al., 2020),这是一个数字分类数据集,包括五个源域。 (3) Office-Caltech10 (Gong et al., 2012),包含来自四个源域的十类图像。(4) DomainNet (Peng et al., 2019),这是最近推出的具有 345 个类和 6 个域的大规模多源域适应基准数据集,如图二所示。鉴于篇幅所限,文章主要汇报DomainNet上的实验结果。
在这里插入图片描述

2、域适应实验

在DomainNet上进行域适应实验。总体而言,KD3A 显著优于所有基线域适应方法,并在剪贴画(Clipart)和素描(Sketch)上达到了和全监督相同的性能。此外,对KD3A的三个模块进行了消融实验用以评估每个模块的贡献。结果表明,Knowledge Vote、Consensus Focus和BatchNorm MMD都能够提高性能,而大部分贡献来自Knowledge Vote,表明 KD3A也可以在那些无法使用 BatchNorm MMD 的任务上表现良好。
在这里插入图片描述

3、KD3A对负迁移的鲁棒性实验

为验证在Consensus Focus的作用下,KD3A 对负迁移具有鲁棒性,文章在DomainNet数据集上人为构建不相关和恶意的源域并进行了模拟实验。其中,选取Quickdraw作为不相关源域,记为IR-qdr,用注毒攻击(Poisoning Attack)构造恶意源域,在好的源域Real中选取​的数据打上错误标签,记为MA-m。对比于两个优秀的加权策略:H-divergence和Info Gain,如Figure 4所示,Consensus Focus可以识别不相关的域并且在恶意域鉴别中为其分配了极低的权重,而其他两个策略无法识别恶意域。如Table 2所示,Consensus Focus的平均准确率也优于其他两个策略。(Domain drop代表的应该是不具有无关或恶意源域的情况下KD3A的准确率)
在这里插入图片描述在这里插入图片描述

4、通信效率与隐私安全实验

为了评估通信效率,文章在不同的通信轮次®的设定下训练 KD3A,并汇报DomainNet上的域适应精度。 如下图所示,KD3A 能够在低通信成本(r = 1)下工作,与联邦对抗域适应(FADA)相比,减少了100倍的通信量。 由于通信成本低,KD3A 对前沿的梯度泄漏攻击具有鲁棒性,这证明了高隐私安全性。
在这里插入图片描述

总结

文章提出了一种有效的方法——KD3A来解决去中心化的UMDA中存在的问题。KD3A的主要思想是在不访问源域数据的情况下,通过知识提炼进行域自适应。在大规模DomainNet上的大量实验表明,我们的KD3A算法优于其他最先进的UMDA算法,并且对负迁移具有很强的鲁棒性。此外,KD3A在通信效率上有很大的优势,对隐私泄露攻击具有很强的鲁棒性。

  • 4
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

联邦学习小白

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值