摘要
主要工作:
提出了一个联邦元学习框架 FedMeta,允许以更灵活的方式共享参数化算法(或元学习器),同时保留客户端隐私,而无需收集到服务器上的数据。
优势:
- 收敛速度更快,通信成本降低。
- 与联邦学习中领先的优化算法联邦平均(FedAvg)相比,准确率提高了3.23%-14.84%。
引言
联邦学习存在的两大挑战:
对于统计挑战,分散的数据是非IID、高度个性化和异构的,从而导致模型准确度显著降低 。
对于系统性挑战,设备的数量通常比传统的分布式设置多几个数量级。此外,每个设备在存储、计算和通信能力方面可能会有显著约束。
联邦学习的不足之处:
联邦学习需要训练一个大型n-way分类器来利用来自所有客户机的数据,而其实仅k-way分类器就足够了,因为它每次只对一个客户机进行预测。庞大的模型增加了通信和计算成本。可以只向客户端发送模型的一部分来更新相关参数,但这需要事先了解客户端的私有数据来决定该部分。
元学习的优势:
在元学习中,目标是在一系列任务上学习一个模型,以便它可以仅使用少量样本来解决新任务
基于初始化的元学习算法,如MAML,在快速适应和在新任务上具有良好泛化性方面是众所周知的,这使得它特别适用于分布式训练数据非IID和高度个性化的联邦设置。
在元学习中,算法可以训练包含不同类别的任务。例如,模型不可知元学习(Model-Agnostic Meta-Learning, MAML)算法,可以通过对k-way任务进行元训练来提供k-way分类器的初始化,而不考虑具体的类别。
相关工作
受上述启发,该论文提出了一种联邦元学习框架,将元学习方法学和联邦学习相结合。在联邦元学习框架中,可以使用MAML对所有n个类别的k-way分类器初始化进行元训练。通过这种方式,联邦元学习具有相当低的通信和计算成本。
联邦元学习:
联邦元学习框架将每个客户端视为一个任务。我们的目标不是训练一个摄取所有任务的全局模型,而是训练一个良好初始化的模型,可以快速适应新任务。元学习算法的核心思想是提取和传播先前任务的内部可传递表示。因此,它们可以防止过拟合,并提高对新任务的泛化能力,从而显示出应对联邦设置中的统计和系统挑战的潜力。
联邦元学习框架流程:
相比之下,在联邦元学习中,维护服务器上的算法,并将其分发给客户端进行模型训练。在元训练的每个episode中,一批抽样的客户端接收到算法的参数,并执行模型训练。然后将查询集的测试结果上传到服务器以更新算法。FedMeta框架的工作流程如图1所示。
以 MAML 作为一个运行示例,我们希望使用所有客户端的数据一起训练模型的初始化。MAML 包含两层优化:
内层循环使用维护的初始化来训练特定任务的模型;
外层循环使用任务的测试损失来更新初始化。
在联邦设置中,每个客户端 u 从服务器检索初始化 θ,使用设备上的支持集的数据训练模型,并将测试损失
在一个独立的查询集
上发送给服务器。服务器维护初始化,并通过收集来自一小批客户端的测试损失来更新初始化。
此过程中传输的信息包括模型参数初始化(从服务器到客户端)和测试损失(从客户端到 服务器),无需将数据收集到服务器。
联邦元学习与联邦学习的区别:
唯一的区别是服务器和客户端之间传输的是(参数化的)算法而不是全局模型。
联邦元学习
元学习方法:
元学习的目标是元训练一个算法A,以便能够快速训练模型,例如深度神经网络,用于新任 务。算法Aφ通常是参数化的,其中其参数φ在元训练过程中使用一组任务进行更新。
元训练中的一个任务T包括一个支持集 和一个查询集
,两者都包含标记的数据点。算法A 在支持集上训练模型f,并输出参数
,我们称之为内更新。然后,模型
在查询集
上进行评估,计算出一些测试损失
来反映Aφ的 训练能力。最后,更新Aφ以最小化测试损失,我们称之为外更新。
元学习训练过程:
在元学习中,通过元训练过程缓慢地从大量任务中学习一个参数化的算法(或元学习器),该算法在每个任务中快速训练一个特定模型。一个任务通常包含一个互斥的支持集和一个查询集。一个特定任务的模型在支持集上进行训练,然后在查询集上进行测试,并使用测试结果来更新算法。
对于服务器端而言:
初始化参数θ(MAML)或者参数θ和超参数α(Meta-SGD)。
对于每一轮episode t,执行以下步骤:
a. 从所有客户端中随机选择一个大小为m的子集,并将参数θ(MAML)或者参数θ和超参数α(Meta-SGD)分发给这些客户端。
b. 并行地对选定的每个客户端u∈执行以下步骤:
i. 使用参数θ对客户端的本地数据进行模型更新,计算测试损失gu。
对于MAML算法,执行 ModelTrainingMAML(θ);
对于Meta-SGD算法,执行 ModelTrainingMetaSGD(θ, α)。
c. 在所有客户端上计算完测试损失后,更新算法的参数。
对于MAML算法,通过取均值更新参数θ;
对于Meta-SGD算法,通过取均值 更新参数θ和超参数α。
对于客户端而言:
对于MAML:
在客户端u上执行MAML的模型训练过程。
从客户端u的数据中采样支持集(support set)Du S 和查询集(query set)Du Q。
计算支持集的损失函数LDu S (θ),这里使用了支持集上的损失函数的平均值。
根据计算得到的LDu S (θ)和学习率α,更新参数θ得到客户端u的临时参数θu。
计算查询集的损失函数LDu Q (θu),同样使用了查询集上的损失函数的平均值。
计算参数θu关于LDu Q (θu)的梯度gu,即对θu求梯度。
将计算得到的梯度gu返回给服务器端,用于全局模型参数的更新。
对于元-SGD:
从支持集(support set)Du S 和查询集(query set)Du Q 中采样数据。
计算支持集的损失函数LDu S (θ),这里使用了支持集上的损失函数的平均值。
使用学习率α,通过对θ进行迭代更新,计算出局部更新后的参数θu。
计算查询集的损失函数LDu Q (θu),同样使用了查询集上的损失函数的平均值。
计算关于θ和α的梯度 ∇(θ,α)LDu Q (θu),即对θ和α同时求梯度。
将计算得到的梯度gu返回给服务器端,用于全局模型参数的更新。
实验
在LEAF数据集(联邦学习领域的基准数据集)上进行了实验,展示了与传统联邦学习方法相比,FedMeta在收敛速度、准确性和系统开销方面提供了更快、更高和更低的性能。
LEAF由 三个数据集组成:
(1)62类图像分类任务的FEMNIST数据集;
(2) 用于下一个字符预测的莎士比亚数据集;
(3)用于2类情感分类的Sentiment140数据集。
实验设置:
在所有的实验中,我们随机选择80%的客户作为训练客户,选择10%的客户作为验证客户, 剩余的作为测试客户,因为我们认为能够推广到新客户的能力是联邦学习的一个关键属性。
如图2所示,FedMeta框架中的所有方法均实现了更快速、更稳定的收敛,并提高了最终的准确性,表明FedMeta框架具有更好的泛化能力,并且能够有效地适应具有有限数据的新客户端。
总结
提出了一种新颖的联邦元学习框架FedMeta。
作者认为未来工作:
(1)可以在研究FedMeta框架在从模型攻击角度保护用户隐私方面是否具有额外的优势,因为目前的联邦学习方法中共享的全局模型仍然明示了所 有用户的隐私,而在FedMeta中,是共享了元学习器;
(2)我们将在线部署我们的FedMeta框架进行APP推荐以评估其在线性能,这涉及许多工程工作尚待完成。