联邦学习(FL)近来引起了高度关注,因为它允许客户端在保持训练数据本地的情况下协同训练模型。然而,由于本地数据分布的固有异质性,训练出的模型通常在每个客户端上表现不佳。为解决这个问题,聚类式FL应运而生,通过将具有相似数据分布的客户端进行聚类来解决问题。然而,这些基于模型的聚类方法往往表现不佳且成本高昂。在本研究中,我们提出了一种基于分布相似性的聚类式联邦学习框架FedDSMIC,通过检测模型对训练数据的记忆来检测客户端级别的基础数据分布,从而对客户端进行聚类。此外,我们将数据分布的假设扩展到更现实的聚类结构。通过学习中心模型,获取集群中的共同数据属性作为良好的初始点。然后,每个聚类中的客户端通过从初始点开始执行一步梯度下降来获得更个性化的模型。在真实世界数据集上的实证评估表明,FedDSMIC优于流行的最新联邦学习算法,同时保持最低的通信开销。
Intro
传统的FL对单个客户端效果往往不好,因为数据异构。客户往往需求不同。一般解决方法两方面:全局模型个性化或则个性化训练模型。
到本地微调降级成纯本地学习,后者直接获得了个性化模型。因此可以划分为簇。
部分方法在服务端根据余弦相似度或者L2范数上模型参数距离进行分类,但这些方法也通常因为高维度和变换不变性而失败。另一些方法随机初始化k个全局模型代表k个簇,每一个客户端挑一个在自己数据集上损失最小的作为分类依据。虽然提升了效果但是多个模型增加了通讯开销。
正确的想法是一个簇里的客户端仍然是松散分开的。传统方法没有充分利用客户端的分布信息,严重影响聚类准确度。
本文提供了一个新颖的聚类多任务联邦学习框架FedDSMIC,基于客户端之间的概率输出分布Kullback-Leibler分歧,对应服务器上指示器样本,准确检测客户端级别的数据分布相似性。受MAML启发本方法目标在于给一个簇中客户端学习一个好的起始点。
贡献:
-
动态聚类联邦学习框架,检测客户端分布,提升聚类准确性
-
一个聚类中的一个模型的不好之处,和两步学习法的好处
-
真实世界模型测试,更少消耗达到更好效果
相关工作
传统联邦学习没直接考虑本地模型表现;个性化联邦学习中全局模型个性化效果往往因为异构而不好。基本上把PFL(个性化联邦学习)分成两种:模型个性化和直接学习个性化模型。聚类+元学习主要属于模型个性化和一部分的学习个性化模型(先聚类后类间学习)
框架和问题定义
有类结构,但类之间仍然不完全相同。
问题定义很简单:最小化所有类中样本损失之和,之后保证类之间客户端相似度最高。u是一个指示变量,表示i在不在集群c中。
有三个变量要学习:类本地模型,客户端本地模型,和集群分配策略。因此采用交替优化策略,固定一个然后解决另一个。
算法:每一轮收集local model和cluster model,通过在一批“指示样本”上计算预测结果后,用KL散度判断相似度。之后更新cluster assignment。其实就是K means
由于目标是一次梯度下降,所以对损失函数求导。但是这里讲的很多东西都是已经有过的内容而非作者的原创内容。
本地模型更新过程:更新tau次,每次更新使用2个batch的数据:基础模型复制一份,复制的模型先更新一次得到一个中间模型,之后在中间模型的基础上再更新一次计算得到梯度,把这个梯度计算到原本的基础模型上。这里就有两个学习率。
实验
使用了五个数据集:MNIST,EMNIST,FEMNIST,CIFAR10和CIFAR100.
-
MNIST:一个MCLR模型,多正态分布逻辑斯蒂回归
-
EMNIST和FEMNIST:CNN模型
-
MobileNet-v2:CIFAR10和CIFAR100
baseline:
-
local:本地数据集训练
-
FedAvg,FedProx
-
IFCA:CFL方法,模型选择经验损失最小的模型
-
FedSEM:CFL方法,使用l2范数最小化本地和中心模型
-
PerFedAvg:一个模型,给所有人用
-
FedDS:本方法,但是不使用元学习
异构性:
-
FEMNIST:天生异构,不同用户自己不同
-
剩下的:将所有标签分给C个集群,之后用狄利克雷分布确定一个集群内各个样本的比例。
indicator sample是事先抽取的
cluster=3,全部客户端参与训练,MNIST和FEMNIST300轮,其余200轮,lr=0.01, batchsize=32,更新次数20,momentum=0.9。注意作者没有说用了多少客户端
Evaluation:local test dataset上测试。FedAvg和FedProx测量本地客户端上test set的准确率。IFCA和FedSEM测试cluster的model,之后根据data size的大小给出加权准确率。给出了5次测试平均值。
结果:
-
效果好速度快,手写数字提升3-5%,cifar10和100提升超过10%
-
IFCA可以和FedDS提供相似准确性但是收敛速度慢。作者认为充分利用了数据分布。
-
FedSEM表现最差,没有正确分类,和FedAvg类似
-
作者认为本方法很Fair(最低的准确率比较高)
-
在未见客户端上的表现,训练之用80%,训练完后用剩下的测试
-
通信消耗,在所有情况下消耗最低且准确率最高。IFCA要求所有模型给所有客户端导致高消耗
敏感度分析
-
高异构情况下本方法好,iid时perfedavg好因为它利用更多样本
-
indicator sample越多表现比较好,有边际递减
-
检测了活跃客户端比例对收敛的影响。0.2,0.5和1.0
-
通信效率,检测了本地更新 tao 和 本地轮次 E 在 CIFAR10 和 CIFAR100 上的效果。需要注意的是,作者实验部分根据不同数据集选择了不同的训练策略。简单的数据集用tau,复杂的数据集用了E
消融测试:只用了聚类就已经效果好了,用了FML更好
进一步讨论,用不同方法训练出来的模型在CIFAR10上给出的特征表示。可以看见local train出来的模型分得很开但是并不能适应generalization;在数据不多的客户端上,local data不够训练一个好的模型,则local train表现就不够好。使用了t-SNE方法展示图片。
t-SNE(t-distributed Stochastic Neighbor Embedding)是一种非常流行的机器学习算法,用于高维数据的可视化。
总结
本工作是一个相对比较常规的工作,但是写作清晰,可见是出自老手。同时采用了很多方法隐蔽地展现了本方法的好处,没有体现其短处。
聚类+元学习两个一结合,加上弱化了indicator sample这个东西的问题让工作看起来非常的promising,同时规避了在时间上的弱势对比。
我个人很不喜欢聚类,因为大部分工作需要提前规定聚类个数,能够自动分类的其条件也比较苛刻。加之各种各样的限制,还有高开销和根本不可能做到的大规模化。