[联邦学习]FedProx
FedProx(Federalized Proximal Algorithm)是一种在联邦学习(Federated Learning, FL)环境下设计的优化算法,旨在处理数据在不同客户端之间可能存在的不均匀分布(Non-IID Data)的问题。联邦学习是一种机器学习设置,允许多个客户端协作训练一个共享的模型,同时保持数据的隐私和安全,因为数据不需要集中存储或处理。
FedProx是Li Tian等人于2018年(论文链接)所提出的一种针对系统异构性鲁棒的联邦优化算法,发表于MLSys 2020上。它相较于FedAvg主要做出了两点改进:
采样阶段 使用了按数据集大小比例,可放回采样,并直接平均聚合(无加权)来获得无偏梯度估计
本地训练阶段 基于近端项优化的思路,魔改了本地训练的目标函数为
L + μ 2 ∣ ∣ w k , i t − w g l o b a l ∣ ∣ 2 L + \frac{\mu}{2}||w^t_{k,i} - w_{global}||^2 L+2μ∣∣wk,it−wglobal∣∣2
"采样"指的是服务器从参与方(客户端)的数据集中选择样本进行模型更新。因此,在FedProx中,采样是服务器在每轮迭代中从参与方的数据集中按照每个参与方数据集大小的比例进行选择的过程。具体来说,如果某个参与方的数据集更大,则它在采样中被选中的概率更高。
因此,这里的"采样"是指服务器在联邦学习中选择参与方的过程,而不是指参与方选择自己的数据的过程。
背景和问题
在标准的联邦学习模型中,如FedAvg(Federated Averaging),每个客户端独立地在本地数据上训练模型,然后将更新的模型发送给中央服务器。服务器将这些更新平均合并,以更新全局模型。然而,当不同客户端的数据分布差异很大时(即Non-IID),这种简单的平均可能导致模型性能下降,因为它没有考虑到各客户端更新的差异性。
FedProx的工作原理
FedProx在FedAvg的基础上增加了一个正则化项,这个正则化项惩罚模型参数与全局模型参数之间的偏差。具体来说,FedProx的目标是最小化以下目标函数:
L ( w ) = ∑ k = 1 K n k n ( F k ( w ) + μ 2 ∣ w − w t ∣ 2 ) L(w) = \sum_{k=1}^K \frac{n_k}{n} \left( F_k(w) + \frac{\mu}{2} |w - w^t|^2 \right) L(