先上结论:本文提出了一种新颖的方法Fed Per,使用现有的联邦学习方法,将深度学习模型视为基础+个性化层,以协作方式对基础层进行训练,在本地进行个性化层训练,用于捕获联邦学习设置中用户的个性化方面。
结果表明,具有基础+个性化层的模型有助于对抗统计异质性的不利影响。在FLICKR - AES和CIFAR数据集上的实验结果表明了FedAvg的无效性和FedPer在建模个性化任务方面的有效性。
背景:
一个单词:statistical heterogeneity 统计学异质性
不同设备数据可能是异构的,本文提出了一种用于深度前馈神经网络联邦训练的基层+个性化层方法Fed Per,可以对抗统计异构性带来的不良影响。
图为提出的个性化联合方法示意图,主要的思想为基本层+个性化层
- 所有的客户设备共享基本层(蓝色),共享相同的权重,这些权重来自参数服务器,所以基本层和服务器也是共享的
- 由于数据分布的不同,不同客户设备具有不同的个性化顶层,可以潜在地适应单个数据。
我们研究了个性化设置作为统计异质性来源对深度前馈神经网络联邦学习的影响。
个性化是机器学习的一个关键应用,并且可能实现,因为从原始用户数据中捕获的用户偏好不同。
在边缘设备属于用户的联合设置中,这必然意味着用于个性化的数据在统计上是异构的。
而将联邦平均扩展到深度神经网络模型或协同过滤无法解决的问题的效果并不明显。个性化联邦学习的正确方法是一个非常重要的问题,研究界对此几乎没有触及。
挑战:
特征提取,相同数据不同标签,这超过了联邦学习学习一个全局模型,并在每个客户端上有效地在本地复制的模式范围。数据多,用户设备多,差异大,难以克服统计学异质性的不利影响。用户数据少时,它们共享需要数据重叠,而像个性化图像美学和个性化高光检测,使用数据集重叠是没有用的
本文贡献:
(1)将深度学习模型视为基本+个性化层来捕获联邦学习中的个性化方面
(2)我们的训练算法包括由联邦平均(或其某些变体)训练的基础层和仅从具有随机梯度下降(或其某些变体)的本地数据训练的个性化层
(3)免于联邦平均( FedAvg )过程的个性化层可以帮助对抗统计异质性的不利影响
模型和算法设置:
模型:根据个性化方法联合示意图,所有用户设备共享相同的基础层,并具有独特的个性化层,构成深度前馈神经网络模型。
模型假设:
(1)假设权重张量WPj在第j个设备上捕获个性化的所有方面。
(2)任何客户端的数据集在全局聚合中都不会发生变化
(3)批大小b和迭代次数e在客户端和全局聚合之间是不变的
(4)每个客户端使用SGD在全局聚合之间更新
(5)在整个训练过程中,所有N个用户设备都处于活动状态。
参数和公式:
意义 | 符号 |
每个客户设备上的基本层数目 | |
每个客户设备上的个性化层数目 | |
用户设备总数 | N |
基本层权重矩阵 |
|
基本层权重矩阵对应的向量值激活函数 | |
第j个用户设备的个性化层权重矩阵,j ∈ {1, 2, . . . , N } |
|
个性化层权重矩阵对应的向量值激活函数 | |
在第j个设备上进行的前向传递操作,即神经网络的forward pass。原始公式看上去十分复杂,可以不用管,只关心输入和输出即可,即客户端的样本首先经过基本层,然后再经过个性化层,最后得到输出 | 可简化为 |
需要优化的损失函数(即所有客户端损失的均值) 其中L( · , ·)表示所有设备共有的每个样本损失函数 学习的目标是通过权重张量 | |
所有设备通用的每个样本损失函数 (即其中第j个设备上的损失) | |
数据集的批次大小 | b |
数据集的迭代次数 | e或k |
学习率 | |
算法设计:
本算法依赖于随机梯度下降( SGD )作为子程序。
最小化经验风险函数(average personalized population risk function )的标准公式要求指定以下内容:
(a).将由SGD更新的决策变量及其初始值
(b).划分数据集的批次大小,数据集的迭代次数
(c).学习率
第j个客户端的步骤 |
|
1:客户端初始化自己的个性化层权重 |
2:这里我怀疑论文作者写错了,文中并没有 |
3:文中写道用上标k表示迭代轮数,所以这里是迭代次数的意思 |
4:这里 |
5:为了既实现个性化又联合联邦学习的优势,使用了全局聚合的参数,也使用了设备本地个性化的权重矩阵,并本设备的学习率 |
6:并且只传输 |
7:在两个服务器之间往复,直到迭代完毕 |
FedPer的服务器组件的步骤 |
|
1:基本层的权重矩阵初始化 |
2:接受各设备传来的可用样本数 |
3:为各设备奉上基本层权重数据 |
4:在每轮迭代中: |
5:接受被各设备(或第j个)更新过的基本层权重矩阵 |
6:使用 |
7:共享基本层权重数据 |
8:结束 |
FedPer的服务器组件的步骤在算法2中详述,第j个客户端的步骤在算法1中描述。服务器使用基于FedAvg的方法在全局上训练基础层,而每个客户端使用SGD风格算法在本地更新其基础层和个性化层(在连续的全局聚集之间)。