文章目录
Group Knowledge Transfer:Federated Learning of Large CNNs at the Edge论文阅读
总述
年份/期刊:NIPS/2020
简单总结:
- FedGKT是为了解决边缘设备算力不够而提出的模型,结合FedAvg、Spilt Learning(SL)、知识蒸馏(KD)三方思想建立的模型。客户端设计为只包含特征提取器和分类器的紧凑CNN,其传给服务器特征。服务器利用GT和客户端预测的软标签进行KD训练,然后将自己预测的软标签传给客户端,客户端进行本地KD训练。最终实现的模型是边缘客户端特征提取器、共享服务器端模型的组合
- 主要依托SL框架,结合FedAcg本地多次更新的思想,并且通过知识蒸馏来弥补本地多次更新带来的精度损失,利用AM思想来约束客户端和服务器端
概览
动机:
- FL需要边缘设备的支持,CNN模型深度越大精度越好—>边缘设备算力不够—>FedAvg算法只能用于小CNN模型或者忽略边缘设备的算力问题,不现实—>解决计算局限性,使用基于模型并行的分裂学习Split Learning (SL) —>SL存在严重的离散问题,需要多次通信,通信成本大—>结合FL和SL,如何在减少通信成本的情况下达成SL效果–>知识蒸馏
优点:
- 结合SL优点,客户端不需要过多的内存和算力
- 结合FedAvg优点,减少通信频率
- 像SL一样,交换隐藏特征降低了通信带宽需求♥♥
- 支持异步训练
结合SL和FL的难点
- 客户端如何利用本地小数据集训练特征提取器
- 客户端和服务器端的loss相辅相成,相互依赖
背景知识
- SL是什么,为什么SL能够解决计算局限性?
SL,分裂学习/分割学习,和FL构成目前两种常见的分布式机器学习方法。FL的主要缺点是每个客户端都需要运行完整的ML模型。
因此SL也有利用价值,SL主要特点是,机器学习模型架构在客户端和服务器之间拆分。因此更加安全,同时也成为资源受限的环境的更好选择。但是缺点是由于跨多个客户端的基于中继的训练,SL 的性能比 FL 慢。
SL试图通过将模型(参数)分成两部分,并将较大的部分加载到服务器端来打破计算约束,但单个小批迭代需要远程正向传播和反向传播。对于边缘计算来说,这种高频率的同步机制可能会导致严重的离散问题,极大地减慢了训练过程。 - 传统知识蒸馏,本文的KD和传统KD差异
参考资料:https://baijiahao.baidu.com/s?id=1673896462976965754&wfr=spider&for=pc
主要目的:压缩模型
主要流程:
知识蒸馏的过程将包括以下几个步骤:
- 训练教师模型:首先,使用原始数据集和交叉熵损失函数来训练一个高精度的教师模型。
- 生成软标签:使用训练好的教师模型对原始数据集进行预测,得到软标签。通常情况下,软标签是由教师模型的输出结果通过某种温度调节方法获得的。
- 训练学生模型:使用原始数据集和包含软标签的损失函数来训练一个学生模型。学生模型的损失函数通常由两个部分组成:一部分是与教师模型的软标签之间的距离,另一部分是与真实标签之间的距离。在训练过程中,学生模型不仅要尽可能地拟合软标签,还要尽可能地接近真实标签。
- 测试学生模型:在测试集上评估学生模型的性能。学生模型的目标是在接近教师模型的性能的同时,具有更小的模型大小和更高的推理速度。
软标签:
温度调节,蒸馏:
本文的KD与先前KD的不同之处:每个客户端只能访问自己的独立数据集。以前的方法使用集中训练,但我们使用交替训练方法
系统设计
系统总览
客户端架构:简单的特征提取器+分类器
服务器端架构:CNN训练模型
Why:文章按照SL的想法进行构建,则应是客户端特征提取+服务器端CNN其他,这样构成了完整的CNN-SL模型。在客户端添加简单的分类器,构成本地可训练的小模型,应是为了结合FL和KD,协调客户端和服务器端的模型收敛
过程:
- 客户端使用本地数据集进行训练,传输特征和软标签
- 服务器端使用特征作为输入来进行训练,然后用基于KD的loss函数最小化GT和软标签的误差,传输自己预测的软标签
- 客户端使用服务器端传来的软标签,基于KD的loss进行训练
具体过程
服务器端和客户端的loss表示为:
CE:预测值和GT之间的交叉熵损失
KL:KL散度
ps,pk:分别是服务器端和客户端的概率预测
直观来看,通过KL来尽量让软标签和GT彼此接近,在这个过程中,服务器端从每个边缘模型获得的知识。类似地,边缘模型试图使其预测更接近服务器模型的预测,从而吸收服务器模型知识来提高其特征提取能力。
接着,使用改进的AM算法,来交替优化两个变量。这里主要指的是,固定服务器端参数不变,更新几轮每个客户端的参数。然后固定客户端的参数不变,然后更新几轮服务器端的参数。
实验
显卡:4张 2080ti
数据集:CIFAR-10,CIFAR-100 ,CINIC-10
Baseline:FedAvg
实验的模型架构:
- FedAvg:ResNet-56和ResNet-110
- FedGKT:客户端为ResNet-8,服务器端为ResNet-55和ResNet-109
实验1:模型准确度实验
有16个客户端和一个服务器端
Centralized(集中式)指的是一种基于中心化服务器的传统机器学习方法,其中所有的数据集和模型都存储在中央服务器上,所有参与方在每轮训练期间都要将其本地数据上传到中央服务器进行模型更新
结论:
- non-IID通常比IID精度要差一点
- FedGKT和FedAvg的效果类似
实验2:计算和通信效率
采用论文[1-2]中的方式,通过浮点数运算来表示运算成本:
边缘设备的计算成本和内存需求主要通过本文设计的ResNet-8来减少
通信效率方面,与同样交换特征的SL进行对比
消融实验
KD知识蒸馏的有效性
只使用CE而不是用KD,发散
客户端使用KD而服务器端不使用,效果好
都是用KD,在复杂数据集上效果好
异步性
异步性不会导致模型精度下降很多,提到SL很难进行异步
边缘节点数目对精度的影响
随着数目的增多,精度不会下降很多
边缘设备模型规模的影响
随着边缘设备的网络越来越简单,效果有下降
总结与收获
- 传输数据的成本问题
引用原文:与整个模型权重或梯度相比,隐藏向量肯定要小得多(例如ResNet-110的隐藏向量大小约为64KB,而32x32图像的整个梯度/模型大小为4.6MB)。即使在高分辨率视觉任务设置中,这一观察结果也成立(例如,当图像大小为224x224时,隐藏特征图的大小仅为1Mb,而ResNet的大小为100Mb)。 - 什么情况下可以传输特征
本文提到,传统的SL传输的一般是特征图,而传统的SL模型实际上是把一个完整的模型进行分割,传统的FL实际上是各方都保存有完整的可训练/不可训练模型
因此,传输特征信息来自于SL的知识体系,若我们的客户端没有完整的模型,这时传特征符合逻辑一些
也就是说,传特征时,我们服务器端必须有可训练的,可以学习特征的模型,而这时客户端不一定需要很复杂很完整的模型 - 存在的未解决问题
网络的可拓展性,以及通信还有优化的空间,而且未对设备进行公平化处理 - 未来发展
文章提到这种客户端和服务器端模型组合产生作用的方式,有可以利用到个性化联邦学习中。例如,我们可以对客户端模型进行多次微调,以查看这种个性化的客户端模型和服务器模型的组合是否更有效。
参考文献
[1] Hernandez, D., T. B. Brown. Measuring the algorithmic efficiency of neural networks. arXiv preprint arXiv:2005.04305, 2020.
[2] Wang, H., K. Sreenivasan, S. Rajput, et al. Attack of the tails: Yes, you really can backdoor federated learning. arXiv preprint arXiv:2007.05084, 2020.