联邦元学习笔记,A Collaborative Learning Framework via Federated Meta-Learning

联邦元学习笔记,A Collaborative Learning Framework via Federated Meta-Learning

一个基于联邦元学习的的合作学习框架,笔记。和原文一起看

摘要

边缘物联网设备需要实时智能决策,但是由于计算资源限制和本地数据限制往往做不到。于是提出一个目标是作为平台的合作学习框架。一个模型首先在一个客户集合上训练,然后在目标边缘节点上快速适应,只用很少的样本。同时也调查了算法的收敛性(温和条件下)和目标节点的适应性能。同时为了对抗脆弱的元学习算法的对抗攻击,又提出了更稳健的版本,基于分布健壮优化,同样分析了收敛性。不同数据集上的实验证明了其有效性。

1. 简介

现在的边缘物联网设备很多都要求实时边缘智能,传统送数据到云服务器分析不现实(高延迟高带宽消耗)。我们希望数据本地存储处理。显然由于孱弱的性能和资源能源限制,直接运行AI表现必然是灾难性的。专门研究这个领域的叫边缘智能。

可以利用边缘节点的相似性解决挑战。因此本文提出一个platform-aided平台辅助的合作学习框架,在其中模型的知识分布式地由联邦节点学习,然后通过平台输送到目标节点利用本地数据集微调。

已经有实验说明使用目标节点上的小数据集微调全剧模型效果有限。另一条路线上联邦多任务学习已经被提出,用于给不同节点训练不同但相关的模型,从而解决边缘节点模型异构性。每一个源节点同时也是目标节点,可以通过利用其他节点的数据和计算资源训练自己独特的模型。显然这需要额外时间和计算开支,不适用实时边缘智能。

基于最近的元学习,为了解决“学什么”的问题,提出一种联邦元学习。元学习背后的依据在于通过许多子任务训练模型的初始参数,从而让模型学到背后的知识,这样可以用在其他任务中,提供较好的模型初始化等参数。在这个任务中我们在所有源节点上进行元学习,这样在目标节点只需要很少样本就可以达到最佳效果,进而实现实时边缘智能。

联邦学习不存在元学习中的已知分布假设,不同节点有相似特点的不同本地模型,可以自动化任务创建因为每个节点都在持续做出智能决策。进而也不需要中心化计算,提供通讯开销和本地计算开销之间的较好灵活取舍(如控制本地更新次数)。

元学习:Meta-learning,让机器学会如何去学习learning to learn。对于数据和标签(x,y),不是直接学习y=f(x),而是学习f=F(x,y)
https://zhuanlan.zhihu.com/p/521307395

主要贡献:

  1. 开创性使用平台辅助合作学习框架,首次使用元学习在分布式结构上获得实时边缘智能
  2. 研究本算法的收敛性,调查了目标节点上模型适应性。首次研究节点相似性和本地更新造成收敛和通讯计算开销影响。在边缘节点上添加了梯度的限制,使用Hessians本地损失函数,解决了分布一致的限制
  3. 使用分布式健壮优化DRO对抗可能存在的元学习漏洞,截至写作为止是首次研究DRO对元学习健壮性
  4. 在不同的数据集上证实有效性。

2. 相关工作

MAML是基于梯度的算法,只用一次更新效果也不错。为了避免前者中的二次倒数,提出了Reptile这种一阶方法,类似于joint learning(共同学习)但是在元学习上效果也很好。关于中心化MAML的非凸函数收敛性也有新文章研究。对抗样本会导致元学习性能大幅下降,ADML算法可以同时利用这两种样本让内部梯度更新和元更新掰手腕,可惜的是难以实现。本文基于DRO的算法更稳定且可以抗更常见干扰(如分布外样本)同时可以调节程度。

联邦学习提出后,提出了FedProx为了解决数据异构性同时表征非凸损失函数的行为。一般的联邦学习并不打算在小样本上学习,联邦元学习则立志于保持节点独立性,同时小数据集训练也能保持较好表现。

元学习和多任务学习都是充分利用相关任务,提高学习效果。但是元学习关注小样本上快速性能表现,多任务学习旨在让源和目标同时学习并准确学习。同时后者需求的样本也多得多。

联邦元学习示意图

3. 实时边缘智能的联邦元学习

设置:S个有自己任务的源节点,一个不在S中的目标节点t。元模型以分布式样式在元学习阶段中训练。

A. 问题定义

假设,每个节点上的任务符合一个函数 f θ f_\theta fθ θ \theta θ是参数集合,实数。假设每个数据集中数据有自己的分布 P i P_i Pi,这样损失函数记录为 l ( θ , ( X j , y j ) ) l(\theta,(X^j,y^j)) l(θ,(Xj,yj))。可以对节点i得到经验损失函数L,即所有点的损失之和除以数据集大小。那么S全体的损失函数就是上述各个节点损失函数加权相加(权重为数据集占总数据比例)。

和MAML保持一致,认为目标t节点有K样本,源节点i样本有训练集和测试集,且训练集大小为K。然后利用在训练集上的损失函数求导得到的梯度对 θ \theta θ只用一次梯度更新,之后计算在测试集上的损失。总体的目标函数是为了最小化带权(和之前一样定义)相加后的测试集合损失。

直觉上,这就是检测模型参数更新对本地测试集上的损失函数的影响,目标为:在模型参数上的小改动(换句话说是损失梯度的方向的改变)会导致任何节点上的任务性能的显著改变。

我们和MAML不同的是我们不要求各个任务服从给定分布。我们使用梯度以及和超参数对应的Hessians本地损失函数的差别来判断节点相似度。不过在不了解数据的情况下平台无法直接解决这个问题。

B. Federated Meta-Learning FedML

每个节点更新 θ \theta θ后上传聚合,同时允许本地更新多步减小通讯开销(往往是瓶颈)。
详细描述:所有S节点在 t = 0 t=0 t=0拥有 θ 0 \theta^0 θ0。之后主要有两步:

  1. 本地更新,根据训练集更新 θ t \theta^t θt,之后根据测试集更新本地模型参数,再更新一次 θ t + 1 \theta^{t+1} θt+1,学习率为元学习率。
  2. 全局聚合,每一个t节点更新 θ t + 1 \theta^{t+1} θt+1到平台,之后平台加权求和得到全局参数,送回各节点。(和普通的联邦学习一模一样)

算法伪代码

快速适应Fast Adoption:得到 θ \theta θ后,同样通过训练集上一步梯度下降获得本地参数。但是不需要通过测试集第二次计算。

相当于:我需要学会如何分辨猫和狗。朋友带着和我相同的疑问开始用猫和狗的照片学习,大致学会以后再用几张照片练手,最后把他学到的知识告诉我

4. 性能分析

A. 收敛分析

为了简化表达,用 G i ( θ ) G_i(\theta) Gi(θ)代表过程中产生的临时损失函数 L i ( ϕ i ( θ ) ) L_i(\phi_i(\theta)) Li(ϕi(θ)) G ( θ ) G(\theta) G(θ)是全局损失函数。

假设1. 每个 L i ( θ ) L_i(\theta) Li(θ)都是强 μ \mu μ凸的
假设2. 每个 L i ( θ ) L_i(\theta) Li(θ)都是H平滑的
假设3. 每个 L i ( θ ) L_i(\theta) Li(θ)的海森都是 ρ \rho ρ李普希兹的
假设4. 每个 L i ( θ ) L_i(\theta) Li(θ)的一阶和二阶梯度与 L w ( θ ) L_w(\theta) Lw(θ)之间的欧氏距离是有上限的

前两个假设很常见并且很多机器学习应用中都存在。假设3是为了更高层次的边缘节点损失函数平滑性,使得表征本地元学习目标成为可能。假设4是为了捕捉节点相似性,所以假设它们有常数上界。特别地,我们认为各个节点的梯度和本地损失函数的海森都有一个常数上界,直觉上这个常数越小表明任务越相似,可以通过大量训练大致理解它。更进一步地,对于一个梯度如上限制的任务分布来说,假设四是必然的,这时假设3、4可以看成典型实现与样本平均之间的距离。简而言之,这里的任务相似性假设比元学习中的所有任务服从已知分布更加实际。要注意的是这些假设没有抛弃元学习的设置。

Hessians:海森(黑塞)矩阵,多元函数的二阶偏导数构成的方阵,描述函数的局部曲率。 H i j = ∂ 2 f ∂ x i ∂ x j ( x ) H_{ij}=\frac{\partial^2 f}{\partial x_i \partial x_j}(x) Hij=xixj2f(x)
https://inst.eecs.berkeley.edu/~ee127/sp21/livebook/def_hessian.html

μ-strongly convex和L-smooth限制了梯度的最大最小变化速度,使得梯度下降变得可控(L-smooth定义了函数“可以凸的”上界和梯度的最大变化速度,μ-strongly convex定义了函数“可以凸的”下界和梯度的最小变化速度,防止梯度变化微弱;综合起来相当于保证了一个函数的梯度变化保持在合理的范围内)。目标函数有了这样的性质,便可以很方便地对其进行收敛性分析,并证明该算法的收敛性。
https://blog.csdn.net/weixin_42534493/article/details/118487431
https://blog.fangzhou.me/posts/20190217-convex-function-lipschitz-smooth-strongly-convex/#lipschitz-smooth

李普希茨条件:|F(X)-F(Y)| <= L*|X-Y|, for all X, Y.限制了函数的改变速率
https://blog.csdn.net/yuyangyg/article/details/78001642

为了表征联邦元学习算法的收敛行为,先研究全局元学习目标 G ( θ ) G(\theta) G(θ)的结构性质,之后研究任务相似度对收敛表现的影响,这些又会被本地多次更新影响从而更加复杂。在这里不进行引理和定理具体的证明和描述,只介绍它的想法和作用。

引理1. 假设1-3成立时,学习率 α \alpha α小于某个值时 G ( θ ) G(\theta) G(θ)是强凸且平滑的。

这表明学习率 α \alpha α较小时, G ( θ ) G(\theta) G(θ)在一步梯度下降时表现得和本地损失函数一样好

定理1. 假设2和4成立时,假设2中的左侧可以找到一个上界。

给定梯度和本地损失函数的海森上界,我们可以找到本地本地元学习目标函数的梯度方差上界,同时保留节点异构性质。

定理2. 基于前述,当假设1-4成立时,T时刻的全局优化目标函数与理想之间的差距有一个上界,表现为初始时刻的全局优化目标函数与理想之间的差值乘上该时刻的某个因子,加上本地多轮更新和任务不相似程度导致的误差。

这说明本地更新次数越大,误差越大,并且影响不小;任务越相似,误差越小。可以通过平衡这两者来达到折衷的目的。

MAML中不允许多轮本地更新

推论1. 在假设成立并且合理选择学习率 α , β \alpha,\beta α,β的情况下,本地更新1轮时,任意时刻的全局优化目标函数与理想之间的差值都要小于等于初始时刻的全局优化目标函数与理想之间的差值乘上该时刻的某个因子

这个推论是显然的,因为本地更新一轮时就和MAML一样了,定理二中后面的本地多轮更新导致的误差变为0

B. 快速适配的性能分析

目标节点t的快速学习性能不仅仅与其本地数据集大小有关,还和它与源节点相似性有关。将 θ c \theta_c θc表示为平台上联邦元学习的输出, θ c ∗ \theta_c^{*} θc作为最佳元学习模型。假设它们之间的距离有上界 ϵ c \epsilon_c ϵc。定义在数据分布 P t P_t Pt上本地平均损失函数为 L t ∗ ( θ ) L_t^*(\theta) Lt(θ),它是各个数据点上的损失函数的期望。

这样我们需要的经验损失函数 L t ( θ ) L_t(\theta) Lt(θ)就是 L t ∗ ( θ ) L_t^*(\theta) Lt(θ)上的样本平均。

定理3. 假设任意一个数据点上的损失函数都是H平滑的,那么对任意大于0的 ϵ \epsilon ϵ都可以以较大的概率得到,优化一步后的模型参数产生的损失函数与最优化参数产生的损失函数之间的距离有上界。

定理三表明任务相似性和本地样本规模产生的影响。特别地最优模型和产生的模型之间的表现差距是有一个和参数之间的距离相关的上界的,这样可以让平台判断源节点和目标节点应当有多相似才能达到给定的学习表现,进而实现目标节点上的边缘智能。

5. 健壮的联邦元学习

研究表明原学习算法很容易受到对抗攻击,导致目标节点性能显著下降。接下来的算法提供了抗攻击性和准确率之间的权衡。

adversarial attack:对抗攻击,对输入样本故意添加一些细微干扰,导致模型以高置信度给出一个错误的输出。
https://zhuanlan.zhihu.com/p/104532285

A. 健壮联邦元学习

我们提出,要从目标节点的训练集获得一个好的模型初始化,它不仅可以对抗 π \pi π距离远离数据集分布 P t P_t Pt的数据,还要保证输入干净数据时效果良好。根据最近的DRO(分布式健壮最优化),可以通过解决一个最小化问题来完成。这个问题可以看成是最小化t轮的本地损失函数加上偏离的数据产生的损失函数期望。
这样联邦元学习的目标函数就需要加上因偏离数据产生的损失函数期望。

DRO:分布式鲁棒优化。基于RO。参数的小扰动会导致最优解不可行,因此通过RO保证解在一个不确定集合中可行。免疫不确定性。
https://arxiv.org/abs/1908.05659 Distributionally Robust Optimization: A Review

B. 基于Wasserstein距离的健壮联邦元学习

使用Wasserstein距离作为分布之间距离的衡量。这里衡量最优化传输代价,也就是任意一对概率测度中距离最小的那对。但是直接使用会导致计算开销很大,因此采用拉格朗日松弛内部的最大化问题,加入一个惩罚因子大于0的 λ \lambda λ。它和 π \pi π成反比。使用Kantorovich对偶解决问题。

Wasserstein距离:衡量P分布需要花多大代价和Q分布保持一致,在机器学习中表现好。也叫推土机代价,相当于将不同凹凸的土堆铲成另一列土需要搬运的土和质量的乘积。
https://zhuanlan.zhihu.com/p/84617531

引理2. 给定分布和测度,可以定义更鲁棒的损失函数。

引理2表示了最坏情况下联合概率测度(对应着将x质量传输给本地最优化问题的传输方案)。可以认为新的问题中损失函数在满足一定条件下是强凸的,可以有效地被梯度下降法解决。

C. 跨节点的鲁棒元学习

每个循环中,节点用特定步数的梯度上升法构建反向数据样本,加入自己的反向数据集中。第一次下降不变,但是第二次的梯度要加上反向数据集上的损失函数后一起梯度下降。

D. 收敛分析

和前面差别不大,省略。通过加入反向扰动样本,学习阶段的节点就可以获得抵抗攻击的能力,同时不显著牺牲学习性能。

6. 实验

A. 实验设置

采用人工合成的数据和现实世界数据(MNIST等)。选择八成节点作为源节点,在剩下的两成上测试适配性能。

B. 效果评估

收敛性: 符合预期,节点相似时误差降低,本地更新次数多时误差上升。在实际应用中非凸设置时仍然可以优质收敛。

快速适配: 显然节点最相似时性能最好;性能远超FedAvg,且节点本地数据集越小其性能差距越大。FedAvg微调时数据越少性能越差,过拟合,FedML改进了这点。

健壮算法快速适配: 首先用干净数据更新元模型,然后分别在干净和脏数据集上测试适配性能。 λ \lambda λ越小算法越稳定。

健壮性和准确性权衡: λ \lambda λ越小,在干净数据集上表现略差,但是脏数据集上表现明显更好。 λ = 0.1 \lambda=0.1 λ=0.1时准确性损失不大但是健壮性远超普通版本。

扰动 ξ \xi ξ 两个版本算法都是扰动越小表现越好,但是健壮算法随着扰动增强效果变得好于普通算法。

7. 总结

  1. 提出了算法并分析收敛性
  2. 提出健壮版本算法并分析收敛性
  3. 实验结果显示了算法的有效性
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值