缓存驱动的联邦学习架构FedCache

FedCache:A Knowledge Cache-driven Federated Learning Architecture for Personalized Edge Intelligence

每日一诗:南陵别儿童入京
唐·李白
白酒新熟山中归,黄鸡啄黍秋正肥。
呼童烹鸡酌白酒,儿女嬉笑牵人衣。
高歌取醉欲自慰,起舞落日争光辉。
游说万乘苦不早,著鞭跨马涉远道。
会稽愚妇轻买臣,余亦辞家西入秦。
仰天大笑出门去,我辈岂是蓬蒿人。

本文是中科院计算所Zhiyuan Wu IEEE Transactions on Mobile Computing(TMC,CCFA)期刊在投的一篇文章阅读笔记。在深入阅读后收获颇丰,遂将对全文的思考和理解写成博客,论文以及代码开放在其个人Github主页,希望能给您带来帮助。同时,不足之初希望大家多多指出。

1.Abstraction

​ 现有的个性化联邦学习(PFL)方法大多基于以FedAvg为代表的基于参数交互的架构(PIA),由于设备和边缘服务器之间的大规模参数传输导致难以承受的通信负担。相比之下,基于Logits交互的架构(LIA)能够通过logits传输来更新模型参数,并且与PIA相比具有通信轻量级和允许异构设备端模型的优势。

然而,以前的 LIA 方法试图依靠公共数据集或增加除 logits (详见文末补充知识)之外的额外feature传输的通信开销来获得令人满意的性能。

为了解决这个困境,我们提出FedCache:一种知识缓存驱动的 PFL 架构。 通过深度预训练神经网络将客户端上的所有隐私样本编码为哈希值,从而以保护隐私的方式辨别样本之间的相关程度。它在服务器上保留一个知识缓存器,用于从与每个给定的设备上样本具有相似哈希值的R个近邻样本中获取个性化知识。在训练阶段,将集成蒸馏应用于设备上模型,以利用从服务器端知识缓存传输的个性化知识进行性能优化。对四个数据集的实证实验证明了 FedCache 与最先进的 PFL 方法的性能相当,通信效率提高了两个数量级以上。

2.Contribution

现有的PFL体系结构不能在系统性能(准确性)、资源效率(通信效率)和不依赖公共数据集之间实现良好的权衡,即便LIA比常用的PIA获得了显著减少的通信负担和容忍异构模型训练的优势。那么,如何设计个性化的联邦学习体系结构,在训练过程中只允许Logits传输,而不需要公共数据集,同时显著优于基于class粒度Logits交互的体系结构?

提出了FedCache,它是第一个基于 Sample-grained Logits 交互的架构(SLIA),且无需特征传输和公共数据集、保证令人满意的性能,同时符合 EI 中实际设备端异构的限制。

1.性能优于CLIA:实现了比CLIA更优越的性能(这是SLIA相较于CLIA与生俱来的优势。但是这两种方式的性能都要比PIA差。)

2.不依赖公共数据集:克服了先前 SLIA 的缺点(支持在训练期间无需公共数据集的帮助下传输样本级logits)

3.异构性:在客户端之间实现模型异构性(不要求模型结构相同,只要最后的logits维度相同即可,这是基于logits的知识蒸馏的优势)

4.高效通信(基于logits知识蒸馏的天然优势)

用原文的话说:

1.友好性:FedCache 是一种设备友好的架构,可以在训练期间仅在客户端和服务器之间传输小规模的集成知识,而无需公共数据集。同时,FedCache支持异构模型设备上的协同训练。

2.可扩展性:FedCache是一种针对大规模设备的高度可扩展的架构,因为它无需在服务器上保留繁琐的全局模型,并且还可以实现异步训练,有效减少服务器端计算量和客户端-服务器同步消耗。

3.高效性:在四个常见数据集上将 FedCache 与具有各种架构的最先进的 PFL 方法进行了比较。结果证实 FedCache 实现了与基准算法相当的性能,同时将通信效率提高了两个数量级

3.Introduction

联邦学习方法安全高效,打破信息孤岛。流行的 FL 方法不能保证模型在异构客户端上的泛化和适应。个性化联邦学习(PFL)被提出来实现客户端本地个性化训练需求和全局模型泛化目标之间的均衡。

现有PFL有两种架构:

在这里插入图片描述

3.1基于参数交互的PFL架构(PIA)

客户端定期将模型参数上传到服务器并下载聚合后的模型更新本地参数,基于PIA的PFL中客户机倾向于只上传部分模型参数进行聚合,以保持个性化能力
image.png
但是,对于通信资源有限的设备来说,大规模的参数传输是不可负担的;此外,PIA在聚合过程中需要设备上的模型架构具有很强的同质性,对于具有不同硬件约束的异构设备,这是难以实现的

3.2 基于Logits知识交互的PFL架构(LIA)

每个客户端对从服务器下载的全局Logits执行基于蒸馏的优化,在训练期间不传输参数,按Logits粒度可分为两类:

3.2.1类别粒度(CLIA)

对于客户端 k k k的某个类别 y i k y^k_i yik的样本,它的Logits输出应接近所有其他客户端上的该类别样本Logits平均值( F l , y i k F^{l,y^k_i} Fl,yik)的聚合值(除了客户端 k k k之外其它客户端平均值 F l , y i k F^{l,y^k_i} Fl,yik之和再闭上 K − 1 K-1 K1)

image.png
image.png
每个客户端只能学习C个类别对应的Logits,客户机学习到的额外服务器端信息很少,有性能瓶颈。

3.2.2样本粒度(SLIA)

设备上模型学习到的Logits知识与样本的数量有关,这种体系结构通常需要不可避免地引入公共数据集或增加通信开销。
(1)特征交换方式SLIA with Features Exchange (SLIA-FE)
模型分为特征提取器和分类器两部分,服务端部署一个大规模的分类器,服务端损失函数由上传的中间特征与Logits相对应的交叉熵损失和KL散度组合而成
image.png
image.png
虽不需要全部参数传输且支持异构模型,但参与者需要由相同的特征维度(模型层数可能不同,但是提取到的每一层特征的维度相同);此外,由于高分辨率图像和长序列数据的特征维数往往较高,因此对设备的特征传输开销仍然很大(可以理解为不需要传递所有参数,传递某些中间层即可,但是即便如此通信开销仍然很大)。此外,这些feature传递过程中也很容易受到反演攻击,不可避免地会损害用户的隐私。
(2)公共数据集方式
客户端k输出Logits应接近 所有客户端在公共数据集上Logits输出的均值
image.png
虽然进一步放宽了模型要求并且传输数据开销也更小,但该方法依赖一个公共数据集,其该数据分布要与本地私有数据分布接近时才能有比较不错的结果,但是数据分布本身就是敏感信息,现实中中心服务器获取不到。

4.知识缓存驱动的个性化联邦学习

我们提出在服务器上保留一个知识缓存,从而能够获取每个样本相关的个性化知识,具体来说,服务器端知识缓存跟踪示例的最新知识,并利用信息检索机制从缓存的知识中搜索每个样本最相关的知识;从其他客户搜索的知识伴随着可靠和有效的相关表示,并转移到样本来源的客户,进行基于建设性蒸馏的建设性优化。在此基础上,可以在服务器和客户端之间实现样本粒度的Logits交互,以确保设备上的模型学习到足够的个性化知识。

4.1 系统架构

image.png
服务端ensemble模块将从知识缓存中获取的知识结合起来,以获得要在客户端上需要蒸馏的个性化知识;
服务端知识缓存模块是我们设计的自组织知识存储结构,能够在服务器端获取每个客户端的相关知识;
客户端模型从局部数据中提取知识,并在蒸馏模块的指导下进行模型更新;
客户端编码器模块将私有数据编码为哈希值,用于初始化知识缓存

在初始化阶段,客户端上生成的哈希代码将一次性上传到服务器。然后在服务器端知识缓存中执行HNSW,目的是检索R个最相关的样本,以匹配通过哈希值的余弦相似度测量的每个样本。在训练阶段,在每次通信中,私有样本的Logits和索引被上传到服务器。然后,根据预先建立的相似度关系,获取每个样本的知识缓存中哈希相似度最高的R个最佳匹配知识,然后进行知识集成,然后将知识出传输到相应的客户端进行本地蒸馏。
由于蒸馏阶段只依赖于客户各自私有数据的高度相关知识,因此生成的模型具有局部适应性和较强个性化能力。

4.2 算法流程

服务器上的知识缓存使得任意本地样本能够在可控的计算复杂度内获取相关的知识,其中相应的哈希值样本提取的知识应该是r最近邻居的哈希值的原始样本。提供相关知识的样本的是与原样本哈希值的最相似R个样本。
为此设计四类键值对pairs,知识缓存有两个主要阶段:初始化和训练

4.2.1 初始化阶段

1.本地预训练网络提取样本hash值

模型采用在ImageNet上预训练的MobileNetV3,去掉最后一个全连接层,并将输出定义为哈希值。
image.png
2.所有客户端将本地样本的标签和哈希值上传到服务器上,以进行初始化知识缓存。

image.png

计算样本hash余弦相似度。
image.png
检索到的与每个样本索引相关的结果保存在IR中,以供后续访问
image.png

至此,初始化阶段结束,服务器得到四个键值对数据。

4.2.2 训练阶段

3.样本Logits输出上传到服务端进行更新
image.png

4.从服务端获取与样本相关的知识

image.png

image.png
5.知识集成聚合
image.png
6.训练
image.png

具体算法:

image.pngimage.png

值得注意的是,哈希值其实就是预训练模型的在所有客户端本地数据上的输出(去掉全连接层),它是一个向量。服务器对于所有样本中每一个类别,都可以得到该类别所有数据的距离(矢量间的距离HNSW)。当某个客户端请求索引 ( k , i ) (k,i) (k,i)的数据知识时,会找出该data最近邻的R个数据的logits进行平均。

哈希值只用来确定data间的距离,在初始化阶段确定后,后续训练阶段不改变。

5.实验

实验这里概述,深度预训练的编码器,我们采用在MmigeNet上预训练的MobileNetV3
300客户端 / 非独立同分布数据设置 / ResNet同构模型异构模型分别实验
同构模型

异构模型

6.问题与讨论

1.基于样本类别的优劣:设备端模型的蒸馏仅基于与本地样本相关的知识,无法学习到自己没有的样本的知识,虽然可以很大程度上提升个性化能力,但忽略了全局知识学习,模型泛化性能会受到影响。

2.哈希值其实就是预训练模型的输出特征,预训练模型可能size很大,预训练模型仍然需要在公开数据集上进行训练,虽然限制条件比代理数据集放宽了(不需要与本地数据分布相近),但是仍然需要公开数据集这种限制存在。哈希值本质上就是为了找到样本间的距离,然后用最近的R个样本来代表当前样本。 如果说能有一种更高效的衡量样本距离(相似性)的方法,那么可以极大地降低上述开销。

3.HNSW较为复杂限制相关参数实时更新,导致FedCache无法支持增量数据的PFL。
是否可以采用无监督的聚类? 初始时根据所有数据的哈希值向量聚合成K个类别,计算最近邻时从聚类后的簇中选择即可,缩小了寻找范围(或者根据观测结果,如果各类别图像的哈希值差别明显的话,放宽限制随机取出同一簇中的近邻。)。因为原始求距离的复杂度时 N l o g N ∗ O ( f ) Nlog^N * O(f) NlogNOf, f f f表示两个向量间的距离公式,当有一个新的数据加入时,就要全部打乱,从新求取所有样本间的距离了,效率很低。聚类后,当每一轮次新的数据加入时,从聚类中比较有限次数即可确定它的IR关系。

4.上传大量哈希值,也会产生大量的开销,它和样本数量与哈希向量维度成成比。

5.本文的异步训练并非异步联邦学习,指的是客户端如果需要随时查询logits,并且随时上传logtis,不用等别人。

补充知识:

logits(knowledge)就是最终的全连接层的输出(在未经过softmax之前的,就是未归一化的概率)

(161条消息) 知识蒸馏的梳理(侵删)基于logits的子类蒸馏_qq y的博客-CSDN博客

知识蒸馏的一个显著特点是,通过蒸馏的方法训练出的Student模型相比使用完全相同的模型结构和训练数据只使用Hard-target的训练方法得到的模型,拥有更好的泛化能力。

好模型的目标不是拟合训练数据,而是学习如何泛化到新的数据。所以蒸馏的目标是让student学习到teacher的泛化能力,理论上得到的结果会比单纯拟合训练数据的student要好。

  • 27
    点赞
  • 28
    收藏
    觉得还不错? 一键收藏
  • 13
    评论
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值