联邦知识蒸馏概述与思考(续)

关注公众号,发现CV技术之美

前文(【科普】联邦知识蒸馏概述与思考)提到知识蒸馏是一种模型压缩方法,通过利用复杂模型(Teacher Model)强大的表征学习能力帮助简单模型(Student Model)进行训练,主要分为两个步骤:

1)提取复杂模型的知识,在这里知识的定义有很多种,可以是预测的logits、模型中间层的输出feature map、也可以是模型中间层的attention map,主要就是反映了教师模型的学习能力,是一种表征的体现;

2)将知识迁移/蒸馏到学生模型中去,迁移的方式也有很多种,主要是各种loss function的实现,有L1 loss、L2 loss以及KL loss等手段。

知识蒸馏可以在保证模型的性能前提下,大幅度的降低模型训练过程中的通信开销和参数数量,知识蒸馏的目的是通过将知识从深度网络转移到一个小网络来压缩和改进模型。

这很适用于联邦学习,因为联邦学习是基于服务器-客户端的架构,需要确保及时性和低通信,因此最近也提出很多联邦知识蒸馏的相关论文与算法的研究,接下来我们基于算法解析联邦蒸馏学习。

 FL-FD 数据增强的联邦蒸馏算法【1】

在联邦学习(Federated Learning: FL)中,在每个设备端执行训练过程需要与模型大小成比例的通信开销,从而禁止使用大型模型,因此,作者寻求在非IID私有数据下可以实现通信高效的设备上ML方法。

作者提出联邦蒸馏(FD)算法,这是一种分布式在线知识蒸馏方法,其通信有效成本的大小不取决于模型大小,而取决于输出尺寸。在进行联邦蒸馏之前,我们通过联邦增强(FAug)来纠正非IID训练数据集。

这是一种使用生成对抗网络(GAN)进行的数据增强方案,该数据增强方案在隐私泄露和通信开销之间可以进行权衡取舍。经过训练的GAN可以使每个设备在本地生成所有设备的数据样本,从而使训练数据集成为IID分布。

联邦蒸馏(FD):在FD中,每台设备都将自己视为学生,并将其他所有设备的平均模型输出视为其老师的输出。每个模型输出是一组通过softmax函数归一化后的logit值,此后称为logit向量,其大小由标签数给出。

使用交叉熵来周期性地测量师生的输出差异,交叉熵成为学生的损失调整器,称为蒸馏调整器,从而在培训过程中获得其他设备的知识,具体损失是:KDLoss(Local_Logit,Global_Logit)+CELoss(Local_Logit,Local_Lable)。FD中的每个设备都存储着本地每个标签的平均logit向量,并定期将这些本地平均logit向量上载到服务器。

服务器将从所有设备上载的本地平均Logit向量平均化,从而得出每个标签的全局平均Logit

  • 2
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值