FedAvg

Communication-Efficient Learning of Deep Networks from Decentralized Data

Abstract

  • 研究背景:手机等移动设备拥有丰富的数据,适合用于学习模型以改善用户体验,但这些数据通常具有隐私敏感性或数据量大的特点,不适合集中存储在数据中心进行训练。
  • 解决方法:提出联邦学习(Federated Learning),让训练数据分布在移动设备上,通过聚合本地计算的更新来学习共享模型。
  • 主要贡献
    • 确定了从移动设备的分散数据中训练的问题是一个重要的研究方向。
    • 选择了一种可应用于该场景的直接且实用的算法,即基于迭代模型平均的FederatedAveraging算法。
    • 通过广泛的实证评估,证明该方法对不平衡和非独立同分布(non - IID)的数据分布具有鲁棒性,并能大幅减少在分散数据上训练深度网络所需的通信轮数。

Introduction

  • 研究动机:移动设备上的大量数据对训练模型有很大优势,但将这些数据集中存储在数据中心存在隐私风险,因此需要一种不需要集中存储数据的学习技术。
  • 联邦学习的定义:联邦学习是一种由中央服务器协调的、参与设备(客户端)组成的松散联盟来解决学习任务的方法。每个客户端拥有本地训练数据集,仅计算当前全局模型的更新并将其发送给服务器,服务器进行模型平均。
  • 联邦学习的优势
    • 训练与直接访问原始训练数据解耦,降低了隐私和安全风险,因为只需要信任协调训练的服务器,而攻击面仅限于设备,而非设备和云。
  • 联邦优化的关键属性
    • Non - IID:给定客户端上的训练数据通常基于特定用户对移动设备的使用,因此其本地数据集不能代表总体分布。
    • Unbalanced:不同用户对服务或应用的使用程度不同,导致本地训练数据量不同。
    • Massively distributed:参与优化的客户端数量远大于每个客户端的平均示例数量。
    • Limited communication:移动设备经常离线或处于慢速、昂贵的连接状态。
  • 实验设置:采用同步更新方案,每轮通信中选择固定比例的客户端,服务器将当前全局算法状态发送给这些客户端,客户端进行本地计算并发送更新给服务器,服务器应用这些更新并重复该过程。
  • 相关工作:此前的一些研究如McDonald等人和Povey等人的工作只考虑了集群/数据中心设置,且未考虑不平衡和非IID数据;Neverova等人的工作与本文动机相似,但也未考虑这些关键因素;在凸优化设置中,一些算法专注于通信效率,但假设条件在联邦优化设置中不成立;异步分布式SGD已应用于训练神经网络,但在联邦设置中需要的更新数量过多;分布式共识算法不适合通信受限的大规模客户端优化。
  • 算法考虑:基于随机梯度下降(SGD)构建联邦优化算法,考虑了基线FederatedSGD(FedSGD)和改进的FederatedAveraging(FedAvg)算法,FedAvg通过增加每个客户端的计算来减少通信轮数。
Federated Averaging Algorithm

首先看FedSGD算法:选择一部分客户端来计算损失梯度以更新全局模型参数。

在这里插入图片描述
在这里插入图片描述

联邦优化相关解释

  • 对于机器学习问题,目标是最小化形如\(\min_{w \in \mathbb{R}^{d}} f(w)\)的目标函数,其中\(f(w) \stackrel{ def }{=} \frac{1}{n} \sum_{i = 1}^{n} f_{i}(w)\)
    • w:模型的参数,它是一个d维的实数向量,模型通过调整这些参数来最小化损失函数。
    • f(w):表示整个模型的损失函数,它是所有样本损失的平均值。我们的目标是找到一组参数w,使得f(w)尽可能小。
    • f_{i}(w):对于每个样本(x_i, y_i),使用模型参数w进行预测时会产生一个损失,f_{i}(w)就是表示这个样本的损失。通过对所有样本的损失进行求和并取平均,得到了整体的损失函数f(w)
  • 假设数据被划分在K个客户端上,P_k是客户端k上数据点索引的集合,n_k = |P_k|
    • P_kP_k是一个集合,其中包含了分配给客户端k的数据点的索引。例如,如果有一批数据总共有n个样本,按照某种方式将这些样本分配到K个客户端上,那么P_k就表示分配给第k个客户端的样本的索引。
    • n_kn_k表示客户端k上数据点的数量。具体来说,|P_k|表示集合P_k中元素的个数,也就是客户端k拥有的数据点的数量。
  • 目标函数可以重写为f(w) = \sum_{k = 1}^{K} \frac{n_k}{n} F_k(w),其中F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} f_i(w)
    • F_k(w):它表示客户端k上本地数据的平均损失函数。具体计算方式是先对客户端k上的所有样本的损失f_i(w)进行求和,再除以样本数量n_k,得到客户端k上的平均损失。
    • f(w) = \sum_{k = 1}^{K} \frac{n_k}{n} F_k(w):这个式子是将整体的损失函数f(w)表示为各个客户端上平均损失函数的加权和。权重是每个客户端上的数据点数量n_k与总数据点数量n的比例。这样的表示方式有助于在联邦学习中考虑数据在不同客户端上的分布情况。
  • 如果数据点在客户端上的划分P_k是通过将训练示例均匀随机地分布在客户端上形成的,那么对于固定的客户端k,有E_{P_k}[F_k(w)] = f(w),这里的期望是对分配给客户端k的示例集取期望。
    • E_{P_k}:表示对客户端k上数据分布的期望运算。也就是说,我们考虑所有可能的将数据分配到客户端k的情况,对这些情况下的F_k(w)取平均值。
    • E_{P_k}[F_k(w)] = f(w):在数据均匀随机分布的理想情况下,客户端k上的平均损失函数的期望就等于整体的平均损失函数。但在实际的联邦学习中,数据往往不是这样均匀分布的,可能存在非独立同分布(non - IID)的情况。

现实例子

假设有一家公司想要开发一个图像分类模型,用于识别不同的物体,例如猫、狗、汽车等。该公司有许多移动设备用户,这些用户的设备上有大量的图片数据。

公司决定采用联邦学习的方式来训练模型,将数据保留在用户的设备上,而不是上传到中央服务器。

现在假设有K = 5个客户端(用户的设备),总数据量为n = 1000张图片。这些图片被标记为不同的类别(如猫、狗、汽车等)。

每个客户端k上的数据点索引集合P_k表示该客户端上的图片索引。例如,客户端1可能有P_1 = \{1, 2, 3,..., 200\},表示它有n_1 = 200张图片。

对于每张图片i,使用模型参数w进行预测时会产生一个损失f_i(w)。例如,如果模型对一张猫的图片预测错误,就会产生一定的损失。

客户端k上的平均损失函数F_k(w)是该客户端上所有图片的损失的平均值。例如,对于客户端1F_1(w) = \frac{1}{200} \sum_{i \in P_1} f_i(w)

整体的损失函数f(w)是所有客户端上的平均损失的加权和。假设每个客户端上的图片数量大致相等,那么f(w) = \sum_{k = 1}^{5} \frac{200}{1000} F_k(w) = \frac{1}{5} \sum_{k = 1}^{5} F_k(w)

现在考虑数据的分布情况。如果数据是均匀随机分布在客户端上的,那么每个客户端上的图片类别分布应该比较相似,即E_{P_k}[F_k(w)] = f(w)。例如,每个客户端上都可能有大致相同比例的猫、狗、汽车等图片。

但在实际情况中,数据可能是非独立同分布的。比如,客户端1上的图片主要是猫,客户端2上的图片主要是狗,客户端3上的图片主要是汽车。这样,每个客户端上的F_k(w)就不能很好地代表整体的f(w)

在联邦学习中,服务器会协调各个客户端进行模型训练。每个客户端根据服务器发送的当前全局模型参数,在本地数据上计算更新,并将更新发送给服务器。服务器再根据这些更新来调整全局模型参数。

例如,在一轮训练中,服务器将当前模型参数发送给客户端123。客户端1在本地的200张猫的图片上进行计算,得到一个关于模型参数的更新。客户端2在本地的狗的图片上进行计算,得到另一个更新。客户端3在本地的汽车图片上进行计算,得到第三个更新。服务器收到这些更新后,综合考虑这些非独立同分布的数据情况,对全局模型参数进行调整。

通过这样的方式,联邦学习可以在保护用户数据隐私的前提下,利用分散在各个客户端上的数据来训练一个有效的模型,尽管数据可能是非独立同分布的。

FederatedAveraging(FedAvg)算法:

FedAvg算法通过三个关键参数来控制计算量,即C(每轮执行计算的客户端比例)、E(每个客户端对本地数据集进行训练的次数)和B(客户端更新时使用的本地小批量大小)。当B = ∞且E = 1时,对应于FedSGD。对于具有(nk)个本地示例的客户端.

在这里插入图片描述

符号含义
  • C:客户端分数,控制每轮参与计算的客户端的比例。
  • E:每个客户端对本地数据集进行训练的次数,即本地训练周期数。
  • B:客户端更新时使用的本地小批量大小。
  • u:每个客户端每轮的预期更新次数,计算公式为(u = (E[nk]/B)E = nE/(KB)),其中(n)为总数据量,(k)为客户端索引,(n_k)为客户端(k)的数据点数量,(K)为客户端总数。
  • η:学习率,用于控制模型训练过程中参数更新的步长。
  • Pk:客户端(k)上数据点索引的集合。
  • nk:客户端(k)上数据点的数量,即(n_k = |P_k|)。
  • f(w):整体的损失函数,通常表示为(f(w) = \sum{k = 1}^{K} \frac{n_k}{n} F_k(w)),其中(F_k(w) = \frac{1}{n_k} \sum{i \in \mathcal{P}{k}} f{i}(w))。
  • fi(w):通常表示在示例((x_i, y_i))上使用模型参数(w)进行预测的损失。
  • ℓ(w; b):在批次(b)上的损失函数。
专有名词含义:
  • FedSGD:一种基线算法,全称为Federated Stochastic Gradient Descent(联邦随机梯度下降),在每轮通信中,选择一部分客户端,计算这些客户端所持有的所有数据的损失梯度,用于更新全局模型参数。
  • FedAvg:全称为Federated Averaging(联邦平均),是文章提出的主要算法,通过在客户端进行更多的本地计算(调整E和B)并在服务器上进行模型平均来提高训练效果并减少通信轮数。
  • MNIST:全称Modified National Institute of Standards and Technology,是一个手写数字数据集,常用于图像识别任务的研究。
  • CNN:Convolutional Neural Network(卷积神经网络),是一种深度学习模型,常用于图像识别等任务。
  • LSTM:Long Short - Term Memory(长短期记忆网络),是一种循环神经网络,常用于语言建模等序列数据处理任务。
  • CIFAR - 10:一个包含10个类别的图像数据集,每个类别有6000张图像,常用于图像分类任务的实验。
  • IID:Independent and Identically Distributed(独立同分布),表示数据在客户端上的一种分布情况,在这种情况下,数据被打乱并均匀地分配给客户端。
  • Non - IID:非独立同分布,表示数据在客户端上的另一种分布情况,与IID相反,数据在客户端上的分布不均匀。
  • SGD:Stochastic Gradient Descent(随机梯度下降),是一种常用的优化算法,用于训练深度学习模型。

实验结果 - 增加并行性

在这里插入图片描述
在这里插入图片描述

  • 讲解内容
    • 首先,我们来看这个实验的目的,是通过改变客户端分数C来控制多客户端并行性,探究其对MNIST模型的影响。
    • 图2展示了MNIST CNN在IID(独立同分布)和pathological non - IID(病理非IID)数据分布下,以及Shakespeare LSTM在IID和by Play & Role non - IID数据分布下,测试集准确率与通信轮数的关系。从图中可以看出,随着通信轮数的增加,测试集准确率逐渐提高。在MNIST CNN的非IID情况下,通过调整C和B的组合,可以明显改善模型的性能。
    • 再看表1,它详细展示了不同C和B组合下,MNIST 2NN和CNN模型达到特定测试集准确率所需的通信轮数以及相对于C = 0基线的速度提升。例如,对于MNIST 2NN在IID情况下,当C = 0.1且B = 10时,达到测试集准确率97%需要87轮通信,相对于C = 0时的1474轮,速度提升了3.6倍。
    • 作者进行这样的实验和分析,是为了找到最优的参数设置,以提高联邦学习的效率。
    • 而FedAvg的好处在这里也有所体现,通过合理设置参数,它能够在不同数据分布情况下,有效地提高模型的准确率,并且减少通信轮数,提高训练效率。

实验结果 - 增加每个客户端计算

在这里插入图片描述

  • 讲解内容
    -具体量化了FedAvg与FedSGD达到目标准确率所需的通信轮数对比。例如,对于MNIST CNN在达到99%准确率的情况下,FedSGD需要626轮,而FedAvg在某些参数设置下(如E = 5,B = 8)只需要179轮,速度提升了3.5倍。
    • 作者这样做是为了验证FedAvg算法通过增加客户端计算来减少通信轮数的有效性。
    • FedAvg的优势在于它能够充分利用客户端的计算资源,通过多次本地更新和模型平均,提高了模型的训练效果,同时减少了与服务器的通信次数,降低了通信成本。

实验结果 - 过优化讨论

在这里插入图片描述

  • 讲解内容
    • 图3展示了在训练初期对Shakespeare LSTM问题使用大E(本地训练周期数)的影响。可以看到,当E很大时,FedAvg可能会出现平稳或发散的情况。
    • 这说明对于一些模型,在训练过程中需要合理控制E的大小,避免过优化。
    • 作者进行这个实验是为了研究模型在不同训练条件下的稳定性和收敛性。
    • 这也从侧面反映了FedAvg算法在应用时需要注意参数的调整,以确保模型能够有效地训练。

实验结果 - CIFAR - 10实验

在这里插入图片描述
在这里插入图片描述

  • 讲解内容
    • 我们在CIFAR - 10数据集上进行实验,进一步验证FedAvg的有效性。图4展示了FedSGD和FedAvg在CIFAR10实验中的学习率曲线,可以看出FedAvg在不同学习率下的性能表现相对稳定,而FedSGD的性能则受到学习率的影响较大。
    • 表3给出了基线SGD、FedSGD和FedAvg达到不同目标测试集准确率所需的通信轮数及速度提升。可以明显看到,FedAvg相比其他两种算法,能够在显著减少通信轮数的情况下达到相似或更好的测试准确率。
    • 作者进行这个实验是为了证明FedAvg在不同数据集上的通用性和优越性。
    • FedAvg在CIFAR - 10数据集上的出色表现,再次证明了它是一种高效的联邦学习算法,能够适应不同类型的数据和任务。

实验结果 - 大规模LSTM实验

在这里插入图片描述

  • 讲解内容
    • 这个实验是在大规模下进行的下一个单词预测任务,以展示FedAvg在现实问题中的有效性。图5展示了FedAvg和FedSGD在大规模语言建模任务中的学习曲线。可以看到,FedSGD需要820轮才能达到10.5%的准确率,而FedAvg在η = 9.0时仅需35通信轮数(23×更少)就能达到同样的准确率,并且FedAvg的测试准确率方差更低。
    • 作者进行这个实验是为了验证FedAvg在大规模实际应用中的性能。
    • 这表明FedAvg在处理大规模数据和复杂任务时,仍然能够快速收敛并提供更稳定的结果,进一步体现了它的优势。

vg在η = 9.0时仅需35通信轮数(23×更少)就能达到同样的准确率,并且FedAvg的测试准确率方差更低。
- 作者进行这个实验是为了验证FedAvg在大规模实际应用中的性能。
- 这表明FedAvg在处理大规模数据和复杂任务时,仍然能够快速收敛并提供更稳定的结果,进一步体现了它的优势。

总之,这些实验结果充分证明了FedAvg算法在联邦学习中的有效性和优越性,它能够通过合理的参数设置和计算优化,减少通信成本,提高模型训练效率和准确率,适用于多种数据集和模型架构。

  • 7
    点赞
  • 7
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
FedAvg pytorch是一个用于联邦学习的算法。它通过对参与者的本地模型进行加权平均来实现全局模型的更新。这个算法的实现非常简单,它首先对每个参与者的模型参数进行平均,然后将平均参数作为全局模型的更新。具体代码如下所示: def FedAvg(w): w_avg = copy.deepcopy(w += w[i][k w_avg[k = torch.true_divide(w_avg[k], len(w)) return w_avg 其中,w是一个包含参与者模型参数的列表。算法遍历每个参数的键值对,将所有参与者的对应参数加和,并将结果除以参与者的数量,得到平均参数作为全局模型的更新。这样,通过不同参与者的贡献,全局模型可以得到更新并获得更好的性能。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [联邦学习算法FedAvg实现(PyTorch)](https://blog.csdn.net/Joker_1024/article/details/116377064)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [联邦元学习算法Per-FedAvg的PyTorch实现](https://blog.csdn.net/Cyril_KI/article/details/123389721)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [PyTorch 实现联邦学习FedAvg (详解)](https://blog.csdn.net/qq_36018871/article/details/121361027)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值