Group Knowledge Transfer:Federated Learning of Large CNNs at the Edge论文阅读

Group Knowledge Transfer:Federated Learning of Large CNNs at the Edge论文阅读

总述

年份/期刊:NIPS/2020
简单总结:

  • FedGKT是为了解决边缘设备算力不够而提出的模型,结合FedAvg、Spilt Learning(SL)、知识蒸馏(KD)三方思想建立的模型。客户端设计为只包含特征提取器和分类器的紧凑CNN,其传给服务器特征。服务器利用GT和客户端预测的软标签进行KD训练,然后将自己预测的软标签传给客户端,客户端进行本地KD训练。最终实现的模型是边缘客户端特征提取器、共享服务器端模型的组合
  • 主要依托SL框架,结合FedAcg本地多次更新的思想,并且通过知识蒸馏来弥补本地多次更新带来的精度损失,利用AM思想来约束客户端和服务器端

概览

动机:

  • FL需要边缘设备的支持,CNN模型深度越大精度越好—>边缘设备算力不够—>FedAvg算法只能用于小CNN模型或者忽略边缘设备的算力问题,不现实—>解决计算局限性,使用基于模型并行的分裂学习Split Learning (SL) —>SL存在严重的离散问题,需要多次通信,通信成本大—>结合FL和SL,如何在减少通信成本的情况下达成SL效果–>知识蒸馏

优点:

  • 结合SL优点,客户端不需要过多的内存和算力
  • 结合FedAvg优点,减少通信频率
  • 像SL一样,交换隐藏特征降低了通信带宽需求♥♥
  • 支持异步训练

结合SL和FL的难点

  • 客户端如何利用本地小数据集训练特征提取器
  • 客户端和服务器端的loss相辅相成,相互依赖

背景知识

  1. SL是什么,为什么SL能够解决计算局限性?
    SL,分裂学习/分割学习,和FL构成目前两种常见的分布式机器学习方法。FL的主要缺点是每个客户端都需要运行完整的ML模型。
    因此SL也有利用价值,SL主要特点是,机器学习模型架构在客户端和服务器之间拆分。因此更加安全,同时也成为资源受限的环境的更好选择。但是缺点是由于跨多个客户端的基于中继的训练,SL 的性能比 FL 慢。
    SL试图通过将模型(参数)分成两部分,并将较大的部分加载到服务器端来打破计算约束,但单个小批迭代需要远程正向传播和反向传播。对于边缘计算来说,这种高频率的同步机制可能会导致严重的离散问题,极大地减慢了训练过程。
  2. 传统知识蒸馏,本文的KD和传统KD差异
    参考资料:https://baijiahao.baidu.com/s?id=1673896462976965754&wfr=spider&for=pc
    主要目的:压缩模型
    主要流程
    在这里插入图片描述
    知识蒸馏的过程将包括以下几个步骤:
  • 训练教师模型:首先,使用原始数据集和交叉熵损失函数来训练一个高精度的教师模型。
  • 生成软标签:使用训练好的教师模型对原始数据集进行预测,得到软标签。通常情况下,软标签是由教师模型的输出结果通过某种温度调节方法获得的。
  • 训练学生模型:使用原始数据集和包含软标签的损失函数来训练一个学生模型。学生模型的损失函数通常由两个部分组成:一部分是与教师模型的软标签之间的距离,另一部分是与真实标签之间的距离。在训练过程中,学生模型不仅要尽可能地拟合软标签,还要尽可能地接近真实标签。
  • 测试学生模型:在测试集上评估学生模型的性能。学生模型的目标是在接近教师模型的性能的同时,具有更小的模型大小和更高的推理速度。

软标签:
在这里插入图片描述
温度调节,蒸馏:
在这里插入图片描述
本文的KD与先前KD的不同之处:每个客户端只能访问自己的独立数据集。以前的方法使用集中训练,但我们使用交替训练方法

系统设计

系统总览

在这里插入图片描述
客户端架构:简单的特征提取器+分类器
服务器端架构:CNN训练模型
Why:文章按照SL的想法进行构建,则应是客户端特征提取+服务器端CNN其他,这样构成了完整的CNN-SL模型。在客户端添加简单的分类器,构成本地可训练的小模型,应是为了结合FL和KD,协调客户端和服务器端的模型收敛
过程:

  1. 客户端使用本地数据集进行训练,传输特征和软标签
  2. 服务器端使用特征作为输入来进行训练,然后用基于KD的loss函数最小化GT和软标签的误差,传输自己预测的软标签
  3. 客户端使用服务器端传来的软标签,基于KD的loss进行训练
    在这里插入图片描述

具体过程

服务器端和客户端的loss表示为:
在这里插入图片描述
CE:预测值和GT之间的交叉熵损失
KL:KL散度
ps,pk:分别是服务器端和客户端的概率预测
直观来看,通过KL来尽量让软标签和GT彼此接近,在这个过程中,服务器端从每个边缘模型获得的知识。类似地,边缘模型试图使其预测更接近服务器模型的预测,从而吸收服务器模型知识来提高其特征提取能力。
接着,使用改进的AM算法,来交替优化两个变量。这里主要指的是,固定服务器端参数不变,更新几轮每个客户端的参数。然后固定客户端的参数不变,然后更新几轮服务器端的参数。
在这里插入图片描述
在这里插入图片描述

实验

显卡:4张 2080ti
数据集:CIFAR-10,CIFAR-100 ,CINIC-10
Baseline:FedAvg
实验的模型架构:

  • FedAvg:ResNet-56和ResNet-110
  • FedGKT:客户端为ResNet-8,服务器端为ResNet-55和ResNet-109

实验1:模型准确度实验

有16个客户端和一个服务器端
在这里插入图片描述
Centralized(集中式)指的是一种基于中心化服务器的传统机器学习方法,其中所有的数据集和模型都存储在中央服务器上,所有参与方在每轮训练期间都要将其本地数据上传到中央服务器进行模型更新
在这里插入图片描述
结论:

  • non-IID通常比IID精度要差一点
  • FedGKT和FedAvg的效果类似

实验2:计算和通信效率

采用论文[1-2]中的方式,通过浮点数运算来表示运算成本:
在这里插入图片描述
边缘设备的计算成本和内存需求主要通过本文设计的ResNet-8来减少
在这里插入图片描述
通信效率方面,与同样交换特征的SL进行对比

消融实验

KD知识蒸馏的有效性

在这里插入图片描述
只使用CE而不是用KD,发散
客户端使用KD而服务器端不使用,效果好
都是用KD,在复杂数据集上效果好

异步性

在这里插入图片描述
异步性不会导致模型精度下降很多,提到SL很难进行异步

边缘节点数目对精度的影响

在这里插入图片描述
随着数目的增多,精度不会下降很多

边缘设备模型规模的影响

在这里插入图片描述
随着边缘设备的网络越来越简单,效果有下降

总结与收获

  1. 传输数据的成本问题
    引用原文:与整个模型权重或梯度相比,隐藏向量肯定要小得多(例如ResNet-110的隐藏向量大小约为64KB,而32x32图像的整个梯度/模型大小为4.6MB)。即使在高分辨率视觉任务设置中,这一观察结果也成立(例如,当图像大小为224x224时,隐藏特征图的大小仅为1Mb,而ResNet的大小为100Mb)。
  2. 什么情况下可以传输特征
    本文提到,传统的SL传输的一般是特征图,而传统的SL模型实际上是把一个完整的模型进行分割,传统的FL实际上是各方都保存有完整的可训练/不可训练模型
    因此,传输特征信息来自于SL的知识体系,若我们的客户端没有完整的模型,这时传特征符合逻辑一些
    也就是说,传特征时,我们服务器端必须有可训练的,可以学习特征的模型,而这时客户端不一定需要很复杂很完整的模型
  3. 存在的未解决问题
    网络的可拓展性,以及通信还有优化的空间,而且未对设备进行公平化处理
  4. 未来发展
    文章提到这种客户端和服务器端模型组合产生作用的方式,有可以利用到个性化联邦学习中。例如,我们可以对客户端模型进行多次微调,以查看这种个性化的客户端模型和服务器模型的组合是否更有效。

参考文献

[1] Hernandez, D., T. B. Brown. Measuring the algorithmic efficiency of neural networks. arXiv preprint arXiv:2005.04305, 2020.
[2] Wang, H., K. Sreenivasan, S. Rajput, et al. Attack of the tails: Yes, you really can backdoor federated learning. arXiv preprint arXiv:2007.05084, 2020.

在“尾数攻击:是的,你真的可以后门联合学习”这个问题中,尾数攻击是指通过篡改联合学习模型中的尾部数据,来影响模型的训练结果以达到攻击的目的。 联合学习是一种保护用户隐私的分布式学习方法,它允许设备在不共享原始数据的情况下进行模型训练。然而,尾数攻击利用了这种机制的漏洞,通过对局部模型的微小篡改来迫使全局模型在联合学习过程中产生误差。 在尾数攻击中,攻击者可以修改尾部数据的标签、特征或权重,以改变训练模型。这可能导致全局模型在聚合本地模型时出现错误,从而得到错误的预测结果。攻击者可以利用这种攻击方式来干扰或扭曲联合学习任务的结果。 为了解决尾数攻击,可以采取以下措施: 1. 发现和识别攻击:通过监控和分析联合学习模型的训练过程,可以检测到异常的模型行为。例如,检查模型的准确性变化、每个本地模型的贡献以及全局模型与本地模型之间的差异。 2. 降低攻击影响:可以采用如去噪、增加数据量、增强模型鲁棒性等方法来减轻尾数攻击的影响。 3. 鉴别合法参与者:在联合学习任务中应对参与者进行身份认证和授权,并且限制恶意攻击者的参与。这样可以减少尾数攻击的潜在风险。 4. 加强安全机制:引入加密技术和鲁棒算法来保护联合学习过程中的数据和模型,防止未经授权的篡改。 综上所述,尾数攻击是一种可能出现在联合学习中的安全威胁。为了保护联合学习任务的安全性和可靠性,需要采取有效的措施来识别、减轻和预防尾数攻击。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值