元学习能快速适应并良好泛化新任务,和分散的训练数据是非独立同分布且高度个性化的联邦学习很契合。首先简单了解元学习的步骤(默认大家都提前了解了),但是对其与联邦学习的结合非常陌生,我选了一篇比较有代表性的论文分别阅读与理解,重点关注整个训练流程。
Federated Meta-Learning with Fast Convergence and Efficient Communication(华为诺亚方舟实验室)
总的来说:这篇文章提出了一个联邦元学习框架 FedMeta,其中共享参数化算法(或元学习器),而不是联邦学习中的全局模型。目标是通过一起使用所有客户的数据来训练模型的初始化。
在联邦元学习过程中,服务器上维护一个算法(元学习器),将此算法分发给客户端进行模型训练。 在元学习每一个episode里,一批采样的客户端接收算法的参数并进行模型训练。 然后将query集上的测试结果上传到服务器进行算法更新。训练框架如下图。当元学习器训练好后,客户端能够快速训练模型。
注意:服务器和客户端之间传输的信息是算法(元学习器参数)而不是全局模型。如果训练分类器,训练全局模型必须要训练全类型分类器(比如有10类数据),但是元学习是每个用户有多少类就训练多少类(比如就只有5类数据),计算通信开销会下降。
下面详看算法:分内外更新,内更新在support集上训练模型f,输出参数
θ
{\theta}
θ;外更新让模型f在query集上评估,计算loss反应训练能力,最小化损失更新。
ModelTraningMAML和MetaSGD可以任选一个,相应的伪代码内容为:
- MAML
算法A用于模型初始化,对于每个任务T,算法保持 θ {\theta} θ是模型f的参数初始值。之后 f θ f_{\theta} fθ在支持集上训练,用损失通过一次或多粗梯度下降步骤更新至 θ T \theta_T θT。最后在query集上测试并计算测试损失。 - Meta-SGD
在MAML基础上,Meta-SGD更进一步同时学习初始化 θ \theta θ和内部学习率 α \alpha α。可以发现测试损失能被看作是 θ \theta θ和 α \alpha α的函数。 θ \theta θ和 α \alpha α向量维度相同。
最后,训练好的元学习器能够应用于具体的任务
u
u
u,初始化参数用
u
u
u的训练集更新,更新所得的参数
θ
u
\theta_u
θu能够被用于做预测。