Communication-Efficient Learning of Deep Networks from Decentralized Data
源码 CCFC 2017 Cited 5849
每日一诗:
《帆影·劳劳亭次别》
明-张居正
劳劳亭次别,无计共君归。
一叶随风去,孤帆挟浪飞。
目穷河鸟乱,望断浦云非。
只在天涯畔,伤心隔翠微
1.Abstraction:
提出了一种实用的基于平均迭代的联邦学习深度神经网络方法——Federated Averaging算法,将每个客户机上的局部随机梯度下降(SGD) 与 执行模型参数平均的服务器结合起来。
证明了它对不平衡和非iid数据分布是稳健的,相较于同步随机梯度下降方法(FedSGD)的通信次数减少10-100倍大大提高了联邦学习得模型效率。
实验部分考虑五种不同的模型架构和四种数据集
2.联邦学习:
2.1 特点:
主要优点是将模型训练与直接访问原始训练数据的需求解耦。对于可以根据每个客户机上的可用数据指定训练目标的应用程序,联邦学习可以通过将攻击面限制在设备上而不是云上,从而显著降低隐私和安全风险。
集中式: 通信耗费少 计算主导
分布式: Fed,通信、计算都耗费很多,用户可能只有在特定时间内(充电 wifi等)参与,故假设每个用户只每天只参与特定轮数
2.2 联邦优化和分布式优化的区别:
2.2.1 非独立同分布
给定客户机上的训练数据通常基于特定用户对移动设备的使用,因此任何特定用户的本地数据集都不能代表总体分布进而导致非独立同分布问题
2.2.2 不平衡
客户端本地训练数据集的大小(size)区别
2.2.3 Massively distributed
希望参与训练的客户端的数量远远大于 每个客户端的平均示例数量
2.2.4 Limited communication
移动设备经常处于离线状态,或者连接速度很慢或连接费用高昂。
3.The FedAvg Algorithm
3.1 解决问题:
由于硬件设备的提升,计算资源冗余 因此可使用额外的计算来减少训练模型所需的通信轮数:
1.增加并行性,我们在每个通信回合之间使用更多独立工作的客户端
2.增加了每个客户机上的计算量,每个客户机在每个通信轮之间执行更复杂的计算,而不是执行像梯度计算这样的简单计算。
本文在并行数量最低限制的基础上探究增加客户机的计算量——FedAvg使用相对较少的沟通来训练高质量的模型
3.2 模型对比
3.3 FedSGD:
典型的联邦学习场景是在本地客户端设备负 责存储和处理数据的约束下,只上传模型更新的 梯度信息,在数千万到数百万个客户端设备上训 练单个全局模型 w。中心服务器的目标函数 F(w) 通常表现为:
其中,m 是参与训练的客户端设备总数,n 是所 有客户端数据量总和, nk 是第 k 个客户端的数据 量, Fkw是第 k 个设备的本地目标函数。
其中, dk 是第 k 个客户端的本地数据集, fi(w)=α(xi,yi,w)是具有参数 w 的模型对数据集dk中的实例(xi,yi) 产生的损失函数。dk中所有实 例产生的损失函数之和除以客户端 k 的总数据量 就是本地客户端的平均损失函数,损失函数与模 型精度成反比。机器学习的目标函数优化通常是让损失函数达到最小值。
联邦学习的目标函数优化算法中,通常采用 大批量随机梯度下降(SGD)算法,即通过本地 客户端模型训练的损失函数,乘以固定的学习率 η,计算出新一轮的权重更新。因此,本地客户 端的模型权重更新如下:
第 t 轮通信中心服务器的模型聚合更新如下:
3.4 FedAvg:
前文所言,FedSGD方法在实践中可以很好地训练出高精度模型,但是会导致客户端和用户间的通信次数大大增多,模型训练效率慢。
由于硬件设备的提升,计算资源冗余 因此可使用额外的计算来减少训练模型所需的通信轮数:
1.增加并行性,我们在每个通信回合之间使用更多独立工作的客户端
2.增加了每个客户机上的计算量,每个客户机在每个通信轮之间执行更复杂的计算,而不是执行像梯度计算这样的简单计算。
本文在并行数量最低限制的基础上探究增加客户机的计算量——FedAvg使用相对较少的沟通来训练高质量的模型。
相较于FedSGD模型,FedAvg提升了每轮次模型更新中本地的梯度更新次数,进而减少了模型更新的轮数。
更新过程如图所示:
计算量由三个关键参数控制:
C,每轮执行计算的客户端比例; C*K为每轮次的客户端总数
E,每轮次每个客户端在其本地数据集上通过的训练次数(遍历数据集的次数)。
B,用于客户端更新的本地数据最小批的大小。(其导数为数据集分成的份数)
则uk表示每轮更新时,客户端k在本地进行模型权重更新的次数。
当B = ∞,E = 1时,FedAvg和FedSGD等价。nk/B=1,uk=1,代表每轮更新时,客户端本地模型参数仅仅更新一次。
B = ∞,代表每一份的数据量无限接近于本地总数据量。则nk/B趋近于1,而不是0。(从实际含义理解而非简单的数学公式)