IEEE ICIP 2019 | 更快更好的联邦学习:一种特征融合方法

前言

在这里插入图片描述
  题目: 更快更好的联邦学习:一种特征融合方法
  会议: IEEE ICIP 2019
  论文地址:https://ieeexplore.ieee.org/abstract/document/8803001

  本文将解读清华大学孙立峰教授团队在2019 IEEE International Conference on Image Processing (ICIP)上发表的论文《Towards Faster and Better Federated Learning: A Feature Fusion Approach》。该论文提出了一种特征融合方法来减少联邦学习中通讯的成本,并提升了模型性能:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。

Abstract

  联邦学习能够在由大量现代智能设备(如智能手机和物联网设备)组成的分布式网络上进行模型训练。然而,FedAvg算法通常需要很大的通信成本,并且性能也是一个很大的挑战,特别是当本地数据以非IID方式分布时。

  因此,本文提出了一种特殊的特征融合机制来解决上述问题:通过聚合来自本地和全局模型的特征,以更少的通信成本实现了更高的精度。此外,特征融合模块为新来的客户端提供更好的初始化,从而加快收敛过程。

1.Introduction

  为了充分利用设备上的数据,传统的机器学习策略需要从客户端收集数据,然后在服务器上集中训练模型,然后将模型分发给客户端,这给通信网络带来了沉重的负担并且暴露于高隐私风险(所有客户端需要暴露自己的数据)。

  2016年,谷歌提出了联邦学习(Federated Learning)的概念,并首次提出了FedAvg算法,它使用本地数据对客户端执行分布式培训,并将这些模型汇总到中央服务器中以避免数据共享。 通过这种方式,减轻了隐私暴露问题。然而,进一步的研究指出,与其他因素相比,通信成本仍然是FL的主要制约因素,例如计算成本,如果模型接受非IID数据训练,则FedAvg的准确性将显着下降。

  在本文中,提出了一种新的具有特征融合机制(FedFusion)的FL算法来解决上述问题。通过引入特征融合模块,在特征提取阶段之后聚合来自局部和全局模型的特征,而几乎没有额外的计算成本。这些模块使每个客户端的训练过程更加高效,并且更有针对性地处理非IID数据,因为每个客户端将为自己学习最合适的特征融合模块。

  本文贡献:

  • 首次将特征融合机制引入联邦学习。
  • 所提出的特征融合模块以高效和个性化的方式聚合来自本地和全局模型的特征。
  • 实验表明本文所提出的方法在精度和泛化能力方面均优于baseline,并且将通信轮数减少了60%以上。

2.Related Work

  考虑到通信成本是限制FL的主要因素,目前已经有一些学者做了相关的研究工作。比如Konecny等人在客户端到服务器通信的背景下提出了结构化和草图更新;Yao等人对设备上的培训程序引入了额外的限制,旨在拟合本地数据的同时整合来自其他客户的更多知识;Caldas等人提出federated dropout来训练客户端的子集,并将有损压缩扩展到服务器到客户端的通信。

3.Methods

  在本节中,首先介绍所提出的特征融合模块,然后给出具有特征融合机制(FedFusion)的联邦学习算法。

3.1 Feature Fusion Modules

  如下图所示:
在这里插入图片描述
  其中蓝色的部分表示local模型提取的两通道特征,灰色部分表示global模型提取到的两通道特征。图1给出了三种特征融合方式:Conv, Multi和Single。特征的提取在CNN中可以理解为经过卷积和池化操作后得到的图片信息。
在这里插入图片描述
  每一个输入的图像 x x x都会分别被局部特征提取器 E l E_l El和全局特征提取器 E g Eg Eg映射到 R C × H × W R^{C\times H\times W} RC×H×W

  随后,特征融合算子 F F F将两个特征提取器提取到的特征进行融合: F ( E l ( x ) , E g ( x ) ) F(E_l(x),E_g(x)) F(El(x),Eg(x)),两个特征融合后被映射到 R C × H × W R^{C\times H\times W} RC×H×W

3.1.1 Conv operator

在这里插入图片描述
F c o n v ( E l ( x ) , E g ( x ) ) = W c o n v ( E g ( x ) ∥ E l ( x ) ) F_{c o n v}\left(E_{l}(x), E_{g}(x)\right)=W_{c o n v}\left(E_{g}(x) \| E_{l}(x)\right) Fconv(El(x),Eg(x))=Wconv(Eg(x)El(x))
  其中 W c o n v W_{c o n v} Wconv表示shape为 2 C × C 2C\times C 2C×C可学习的权重矩阵。具体操作就是将global特征和local特征进行concat(||)后进行卷积操作。

  关于卷积中通道C、高度H以及宽度W的解释可见:DL入门(1):卷积神经网络(CNN)

3.1.2 Multi operator

在这里插入图片描述
  Multi算子:用一个 λ \lambda λ权重矩阵来对local和global进行一个加权求和。

3.1.3 Single operator

在这里插入图片描述
  Single算子:用一个标量 λ \lambda λ来对local和global进行一个加权求和。

  经过上述操作后,global特征提取器提取到的特征和local特征提取器提取到的特征将融合成为一个新的特征,特征shape为 R C × H × W R^{C\times H\times W} RC×H×W

3.2 Federated Learning with Feature Fusion Mechanism

  本节讲述带有特征融合机制的联邦学习策略!

  本文所提出的FedFusion的典型训练迭代如下图所示:
在这里插入图片描述
  具体来讲:

  客户端在第 i i i轮训练时,将会保留服务器发来的全局的特征提取器 E g E_g Eg,在本地分类器进行迭代更新时,会考虑将 E g E_g Eg E l E_l El进行融合。

  在训练期间, E g E_g Eg被冻结并且引入了3.1中描述的附加特征融合模块。

  在客户端上进行训练后,将与特征融合模块结合的本地模型发送到中央服务器进行模型聚合,这里使用指数移动平均策略来平滑更新。

  算法伪代码:
在这里插入图片描述
  对中央服务器:

  1. 初始化全局参数 G 0 G_0 G0
  2. 对第r轮更新:随机选择m个客户端,然后对这m个客户端做如下操作:将全局参数 G r G_r Gr传递给客户端,算出每一个客户端返回的梯度。最后,根据这些梯度进行指数移动平均,合成新的全局参数 G r + 1 G_{r+1} Gr+1

  对客户端t的第r轮训练来说:

  1. 局部参数 L r t = C ∘ F ∘ E l L_r^t=C\circ F\circ E_l Lrt=CFEl,也就是说局部模型是一个分类器,其中 E l E_l El是本地特征提取器(是需要通过数据来进行学习的,初始时就为全局的特征提取器),提取后经过F特征融合,最后再进行分类。
  2. 对每一个bach内的数据,计算 C ∘ F ( E l ( x ) , E g ( x ) ) C\circ F(E_l(x),E_g(x)) CF(El(x),Eg(x))模型的梯度,然后反向传播更新参数。注意这里的模型,实际上就是本文的创新点所在,本地训练时,模型的特征并不只是简单的本地特征,而是将上一轮的全局模型的特征提取器提取到的特征与本地特征进行融合,融合后再进行训练。
  3. 训练结束后将最新的局部参数传递给服务器,由服务器进行指数移动平均,聚合形成新的全局参数。

4.Experiments

4.1 Experimental Setup

  在实验中使用MNIST和CIFAR10作为基本数据集。

  对于MNIST数字识别任务,使用与FedAvg相同的模型:具有两个5×5卷积层的CNN(第一个具有32个通道,第二个具有64个通道,每个之后是ReLU激活和2×2最大池化),512个节点的完全连接层(ReLU+Random Dropout),softmax输出层。

  对于CIFAR10,使用具有两个5×5卷积层的CNN(均具有64个通道,每个通道之后是ReLU激活和3×3最大池化,stride为2),两个完全连接层(第一个具有384个节点,第二个具有192个节点,每个之后是ReLU+Random Dropout)和最终的softmax输出层。

  数据分割方式:

  1. Artificial Non-IID Partition:每个节点仅包含两种类别。
  2. User Specific Non-IID Partition:每个节点包含相似的类别,但是采用不同的分布,类似multi task学习。
  3. IID分布。

4.2 Artificial Non-IID Partition

在这里插入图片描述

  a和b表述了在人工形成的非IID场景下, FedFusion和FedAvg的收敛图。可以看到,在相同的通讯轮数下,不进行特征融合,也就是FedAvg的表现是最差的,其精度最低。

  (图有些看不清),具体的数据如下表所示:
在这里插入图片描述
  可以看到进行特征融合后(无论哪一种特征融合),模型的精度都有所提升。

  Multi融合方式的效果最好,Conv融合方式次之。

4.3 User Specific Non-IID Partition

  为了模拟用户特定的非IID分区,对每个客户端的MNIST应用不同的排列,这就是之前几项研究中所谓的置换MNIST。

  表2列出了达到某些精度(此处为94%和95%)的通信轮数以及通信轮数相对于FedAvg的减少:
在这里插入图片描述
  从上表可以看出,FedFusion+Conv实现了通讯轮数最大幅度的降低。

  值得注意的是,用户特定的“非IID分区更接近现实的FL场景,因此在这种情况下改进更有意义。

4.4 IID Partition

  如下图所示:
在这里插入图片描述
  在IID场景下,使用Multi和Conv进行融合可以以较低的通信成本实现更高的精度。

  对特征融合算子做出如下简要概括:

  1. Multi算子在局部和全局特征映射之间提供灵活的选择,并且更易于解释。 权重向量 λ \lambda λ考虑了相应通道中全局特征映射的比例。当客户端数据类别存在差距时,Multi算子将学习选择最有用的特征映射。
  2. Conv算子更擅长整合全球和本地模型的知识。 如果客户端的数据具有相似的类别但遵循不同的分布,Conv算子的性能要好得多。
  3. 实验表明,Single算子几乎没有改进,不推荐使用。

5. Conclusion

  联邦学习巨大的通讯成本是一个需要解决的紧急问题。 在本文中,尝试从减少沟通轮次的角度进行一些改进:提出了一种新的具有特征融合模块的FL算法,并在当前较为流行的FL设置中对其进行评估。实验结果表明,该方法具有较高的精度,同时将通信轮次减少了60%以上。

  未来的工作可能包括将目前的算法扩展到更复杂的模型和场景,以及将通信轮次减少策略与其他类型的方法(例如梯度估计和压缩)相结合。

  • 8
    点赞
  • 45
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
### 回答1: 联邦平均算法客户端的训练详细流程如下: 1. 初始化模型参数:客户端首先从服务器获取全局模型参数,并使用这些参数初始化本地模型。 2. 加载本地数据集:客户端加载本地数据集,并将其拆分为多个小批次。 3. 训练本地模型:客户端使用本地数据集训练模型,并更新本地模型参数。 4. 计算本地模型参数更新量:客户端计算本地模型参数与全局模型参数之间的差异,即本地模型参数的更新量。 5. 发送本地模型参数更新量:客户端将本地模型参数更新量发送给服务器。 6. 接收全局模型参数更新:客户端从服务器接收全局模型参数的更新,并使用这些参数更新本地模型。 7. 重复步骤 2-6:客户端在本地数据集上继续训练模型,并将本地模型的更新量发送给服务器,直到达到预定的迭代次数或收敛条件。 8. 完成训练:客户端训练完成后,将本地模型参数发送给服务器,以供全局模型的更新。 ### 回答2: 联邦平均算法客户端的训练详细流程如下: 1. 客户端选择参与训练的模型,并下载所需的客户端代码和数据。 2. 在本地环境中,客户端加载训练所需数据集,并对数据进行预处理,例如数据清洗、标准化等。 3. 客户端将处理后的数据分成多个小批次(mini-batches),以便进行分布式训练。 4. 客户端与服务器建立连接,并发送自己的模型参数和数据批次给服务器。 5. 服务器收到客户端发送的数据后,将客户端的数据合并到全局模型中。 6. 服务器在全局模型上进行模型更新,例如使用梯度下降等优化算法对模型参数进行优化。 7. 服务器将更新后的模型参数发送给所有参与训练的客户端。 8. 客户端接收到服务器发送的全局模型参数后,将其应用于本地的模型中。 9. 客户端使用本地的模型参数对自己的数据进行训练,并得到本地模型的更新参数。 10. 客户端将本地模型的更新参数发送给服务器。 11. 服务器将接收到的本地模型更新参数合并到全局模型中。 12. 重复步骤6至11,直到达到预设的训练轮数或达到训练目标。 整个训练流程中,客户端通过与服务器的交互来完成模型参数的更新和同步。这种分布式训练方式充分利用了客户端的本地数据,保护了客户数据的隐私,同时实现了全局模型的不断优化和改进。每个客户端只需关注本地的数据和模型更新,与其他客户端之间相互独立并行,从而提高了整体的训练效率和模型的准确性。 ### 回答3: 联邦平均算法客户端的训练详细流程包括以下步骤: 1. 数据准备:客户端首先从自己的本地数据集中选择一部分样本作为训练数据。这些数据可能包含标签或特征。在开始训练之前,客户端需要确保数据的质量和完整性。 2. 模型初始化:客户端初始化一个模型,通常使用某种预定义的模型架构。该模型定义了待学习的参数和网络结构。 3. 模型训练:客户端使用本地的训练数据通过优化算法进行模型训练。优化算法的选择取决于具体的问题和算法需求,常见的优化算法包括梯度下降、Adam等。客户端通过最小化损失函数来更新模型的参数,使得模型能够更好地拟合数据。 4. 参数聚合:在一定的训练轮次后,客户端将训练得到的模型参数上传到联邦服务器,与其他客户端的参数进行聚合。这个步骤可以通过加权平均等方法来实现,每个客户端的贡献权重取决于其数据特性和信任度。 5. 更新模型:联邦服务器将聚合后的参数发送回每个客户端。客户端使用这些更新后的全局参数进行下一轮的本地模型训练。 6. 重复迭代:以上过程进行多轮迭代,直到模型达到了预定的收敛条件或训练达到了预定的轮次。 通过联邦平均算法,每个客户端能够在自己的本地数据上进行模型训练,同时通过参数聚合保持了数据的隐私性。这种分布式训练方式能够充分利用各个客户端的数据来提高模型的泛化性能和精度。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Cyril_KI

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值