在进行联邦学习中解决异质性问题的调研,这篇文章发表于2022的AAAI,使用原型来解决这部分问题,同时还使用“模型异构“的概念(之前一般见到的都是系统异构)。
一、阅读目标
-
了解使用原型学习解决异质性的思路,以及可以借鉴的方向
-
区分模型异构与系统异构两个概念,并总结
二、问题回答
- 构建以类别为单位的原型,服务器端通过原型的传递代替模型的聚合,并将聚合的原型返回客户端用于训练
- 我认为二者是从属关系,模型异构应该属于系统异构的一个方面,但现在没有官方定义,大家仍然混用
三、摘要
动机:客户端之间的知识聚合发生在梯度空间时,会阻碍优化收敛和泛化性能。比如,客户端的数据分布、网络延迟、输入/输出空间和模型结构等都会轻易导致他们本地梯度之间的不对齐
解决方案:提出联邦原型学习框架,客户端与服务器之间可以传递抽象类原型,而不是梯度。具体来说,此框架首先从不同的客户端收集本地原型进行聚合得到全局原型,然后将全局原型传回给客户端,来规范本地模型的训练,使本地原型不断接近于相应的全局原型。
主要论点:(作者没说,我凭借动机和解决方案总结一下)客户端之间的异质性会轻易导致他们本地梯度之间的不对齐,但抽象类原型的传递可以避免这种不对齐
注:1)整体想法听起来很像Model-Contrastive Federated Learning文章的思路(模型简称为MOON),MOON模型的设计是,在模型参数的传递之外,客户端和服务器之间传递模型hidden layer得到的特征向量,构造本地特征向量和全局特征向量之间的对比损失,令此对比损失参与客户端的训练过程,使得客户端的本地特征向量不断逼近全局特征向量 2)相比于MOON,FedProto的不同之处在于,他只传递原型,而不传递模型参数(由于知识的匮乏,不太懂这是怎么完成的)3)主要论点是怎么完成的(为什么原型的传递可以避免客户端之间本来会出现的不对齐)
四、引入
传统的联邦学习包括两大挑战:数据和模型上的异质性;
前者称为统计异质性,解决方案包括:1)为不同的本地分布保持多个全局模型,比如聚类FL的工作(暂时没读过这方面的文章,从命名上看,是说对所有的本地模型进行聚类,相当于为相近的本地模型保留异质性) 2)统计异质性的第二个解决方案是,个性化的FL,通过利用全局和本地的信息为每个客户端生成个性化的模型。但是这些方法都是基于梯度的聚合,会导致高通信成本,以及对同质本地模型的依赖
后者称为模型异质性(最初这点被称为系统异质性,这两个叫法,我prefer the latter,因为系统异质性应该包括模型异质性), 范畴包括客户端之间不同的硬件和计算能力;解决方案为:基于知识蒸馏的FL,通过向学生模型传递教师模型的知识,来平衡不同模型之间的异构性,这个方案的弊端有二:1)需要一个额外的公共数据集去对齐学生和教师模型之间的输出,因此增加了计算成本 2)当公共数据集和client的本地数据集之间的分布差异增加时,整个模型的性能会退化
尽管数据和模型是异构的。客户端仍然可以通过交换表示信息来共享知识。举个例子,人们对于”狗“的认知(原型)是不同的,这是由于人之间的经历(数据)不同和处理信息的方式(模型)不同造成的,但我们仍然可以通过交换这种认知来获得一个更全面的认知。(我对这一段的理解是,统计异构和模型异构都属于过程的范畴,而原型是一个结果,或者说是目的)
于是提出了一个基于原型聚合的框架,这种方式对异质性的联邦场景有着巨大的潜力,表现在,任何联邦场景(无论是数据异构还是模型异构),均可以使用这种方式完成模型的聚合。方法是,每个抽象原型都可以通过由属于同一类别的观察样本转变为的平均表示来表示一个类
注:1)关于统计异质性和模型异质性的这三种解决方案,我的了解都不深,需要再重点看看 2)看完摘要,感觉这篇文章的立意确实很强,大多数联邦学习文章,都是从如何解决异构性的角度出发,而这篇文章是说,你异构就异构吧,我从别的地方下手,另辟蹊径解决问题 3)下面重点就是看他是使用何种方法得到这种原型表示了,以及这种原型表示能不能像他说得那样有着类似人类”认知“的强大表达能力(如果没有就是高开低走了)
五、相关工作
一般来说都不整理相关工作的,但最近正好要做关于异质性的调研,正好来研究下
统计异质性的解决方案:1)FedProx为代表的用本地正则化项来优化模型 2)个性化模型(包括元训练策略) 3)对本地模型聚类 4)自监督学习策略
模型异质性的解决方案:1)基于知识蒸馏的FL 2)神经网络搜索与FL的结合 3)共同学习平台的 4)基于功能的神经匹配
原型学习:原型是指多个特征的平均。本文使用原型来表示一个类,并在异质FL的设定中使用原型聚合
注:1)超级大问题:这个原型是用来表示一个类的特征的平均,那么跟原数据的直接概率分布有什么不同? 2)不过就是传递网络中某一层之后的特征平均,为什么就能替代模型参数?这表达能力完全不是一个级别吧?而且server该如何聚合呢?client拿到server返回的数据又该如何利用呢? 3)没有类别概念的任务,是不是可以直接对client进行原型学习,然后在server端进行原型聚合?
六、方法
1. FedAvg
FedAvg中的目标函数:
a
r
g
m
i
n
w
∑
i
=
1
m
∣
D
i
∣
N
L
S
(
F
(
w
;
x
)
,
y
)
arg min_w \sum_{i=1}^m \frac{|D_i|}{N}L_S(F(w;x),y)
argminw∑i=1mN∣Di∣LS(F(w;x),y)
D
i
D_i
Di表示第i个客户端的数据集,
N
N
N为整体数据量,
F
F
F为共享模型,
L
S
L_S
LS表示损失函数,此目标函数用来最小化加权的客户端损失函数
2. 真实联邦场景
真实的联邦学习场景,会存在统计异质性和模型异质性,则目标函数变为:
a
r
g
m
i
n
w
1
,
w
2
,
.
.
.
,
w
m
∑
i
=
1
m
∣
D
i
∣
N
L
S
(
F
i
(
w
i
;
x
)
,
y
)
arg min_{w_1,w_2, ..., w_m} \sum_{i=1}^m \frac{|D_i|}{N}L_S(F_i(w_i;x),y)
argminw1,w2,...,wm∑i=1mN∣Di∣LS(Fi(wi;x),y)
模型异质性会导致目标函数的变化,统计异质性则不会;但统计异质性会对联邦优化的结果造成影响
3. 基于原型的聚合
在基于原型的问题中,仍然存在统计异质性和模型异质性。但不同的模型,都可以分成两部分:表示层(Representation layers) f i ( ϕ i ; x ) f_i(\phi_i;x) fi(ϕi;x)和决策层(Decision layers),后者通常指示网络的最后一层,那么表示层Representation layers则代表除了最后一层之外的其他层;
原型的表示是基于类别的,其中第
i
i
i个client中的第
j
j
j类别的原型表示为:
C
i
(
j
)
=
1
∣
D
i
,
j
∣
∑
(
x
,
y
)
∈
D
i
.
j
f
i
(
ϕ
i
;
x
)
C_i^{(j)} = \frac{1}{|D_{i,j}|}\sum_{(x,y) \in D_{i.j}}f_i(\phi_i;x)
Ci(j)=∣Di,j∣1∑(x,y)∈Di.jfi(ϕi;x)
client要对每个类别都聚合一个原型,假设共有
E
E
E类,则每个client向server传递
E
E
E个原型
目标函数可以写为:
a
r
g
m
i
n
{
C
ˉ
(
j
)
}
j
=
1
∣
C
∣
∑
i
=
1
m
∣
D
i
∣
N
L
S
(
F
i
(
w
i
;
x
)
,
y
)
+
λ
∑
j
=
1
∣
C
∣
∑
i
=
1
m
∣
D
i
.
j
∣
N
j
L
R
(
C
ˉ
i
(
j
)
,
C
i
(
j
)
)
arg min_{\{\bar C^{(j)}\}_{j=1}^{|C|}} \sum_{i=1}^m \frac{|D_i|}{N}L_S(F_i(w_i;x),y) + \lambda \sum_{j=1}^{|C|} \sum_{i=1}^m \frac{|D_{i.j}|}{N_j}L_R(\bar C_i^{(j)},C_i^{(j)})
argmin{Cˉ(j)}j=1∣C∣∑i=1mN∣Di∣LS(Fi(wi;x),y)+λ∑j=1∣C∣∑i=1mNj∣Di.j∣LR(Cˉi(j),Ci(j))
C
i
(
j
)
C_i^{(j)}
Ci(j)是第
i
i
i个client的第
j
j
j类别的原型,则
C
ˉ
i
(
j
)
\bar C_i^{(j)}
Cˉi(j)是server端对所有client的第
j
j
j类原型的聚合原型(
i
i
i下标单单指示不同client的训练,但每一轮中不同client的
C
ˉ
i
(
j
)
\bar C_i^{(j)}
Cˉi(j)是相同的)
client端,保持
C
ˉ
i
(
j
)
\bar C_i^{(j)}
Cˉi(j)不变,优化
w
i
w_i
wi和
C
i
(
j
)
C_i^{(j)}
Ci(j)
server端,通过下列公式聚合全局原型:
C
ˉ
(
j
)
=
1
∣
N
j
∣
∑
i
∈
N
j
∣
D
i
,
j
∣
N
j
C
i
(
j
)
\bar C^{(j)} = \frac{1}{|N_j|}\sum_{i \in N_j}\frac{|D_{i,j}|}{N_j}C_i^{(j)}
Cˉ(j)=∣Nj∣1∑i∈NjNj∣Di,j∣Ci(j)
其中,
N
j
N_j
Nj表示所有包含
j
j
j类的client的集合
模型测试阶段,如果出现一个新的client,那么首先使用预训练的模型,比如ImageNet上训练的ResNet18对表示层进行初始化,随机化决策层,然后使用这个公式来调整参数:
a
r
g
m
i
n
j
∣
∣
f
(
ϕ
;
x
)
−
C
(
j
)
∣
∣
2
argmin_j||f(\phi;x)-C^{(j)}||_2
argminj∣∣f(ϕ;x)−C(j)∣∣2
表示通过最小化表示向量和聚合原型之间的
L
2
L2
L2距离,然后就可以令新的client进行预测了