一.介绍
使用非IID数据进行个性化跨思洛联盟学习的根本瓶颈是错误地认为一个全局模型可以适合所有客户。基于此,在本文中,作者通过一种新颖的消息传递机制来解决具有挑战性的个性化跨思洛联盟学习问题,该机制通过迭代鼓励相似客户进行更多协作,自适应地促进客户机之间的潜在成对协作。
二. 相关工作
全局联邦学习训练一个单一的全局模型以最小化所有客户数据的经验损失函数。然而当不同客户端之间的数据不是IID时,很难收敛到一个良好的全球模型。
本地特定方法通过对本地数据的自身处理来进行个性化。这里有几种方法可以进行定制。一是局部微调,使每个客户端的私有数据对全局模型进行微调,为客户机生成个性化模型。同样的,元学习方法可以通过根据客户端的本地数据训练全局模型来定制个性化模型。模型混合方法通过将全球模型与客户的潜在本地模型结合,为每个客户定制。还有的通过梯度调整,以纠正客户在个性化模型和全局模型之间的漂移。
三. 个性化联邦学习问题
考虑有
m
m
m个客户端
C
1
,
.
.
.
,
C
m
C_1,...,C_m
C1,...,Cm,有相同类型的模型
M
\mathcal{M}
M,只是其中的参数
w
1
,
w
2
,
.
.
.
,
w
m
w_1,w_2,...,w_m
w1,w2,...,wm不同。使用
M
(
w
i
)
\mathcal{M}(w_i)
M(wi)和
D
i
D_i
Di表示客户端
C
i
C_i
Ci上的模型和数据。使用
V
i
\mathcal{V_i}
Vi和
V
∗
(
i
)
\mathcal{V}^*(i)
V∗(i)表示在客户端i上的表现以及最佳表现。个性化联邦学习旨在协同利用数据集训练个性化模型,也就是希望更新
M
(
w
1
)
,
.
.
.
,
M
(
w
m
)
\mathcal{M}(w_1),...,\mathcal{M}(w_m)
M(w1),...,M(wm)使结果接近近
V
∗
(
1
)
,
.
.
.
,
V
∗
(
m
)
\mathcal{V}^*(1),...,\mathcal{V}^*(m)
V∗(1),...,V∗(m)。
更具体的说,定义
F
i
:
R
d
→
R
F_i:\mathbb{R}^d\rightarrow\mathbb{R}
Fi:Rd→R为训练函数,将个性化的问题描述为:
min
w
{
G
(
W
)
:
=
∑
i
=
1
m
F
i
(
w
i
)
+
λ
∑
i
<
j
m
A
(
∣
∣
w
i
−
w
j
∣
∣
2
)
}
(1)
\min_w \{ \mathcal{G}(W):=\sum^m_{i=1}F_i(w_i)+\lambda\sum_{i<j}^mA(||w_i-w_j||^2)\} \tag1
wmin{G(W):=i=1∑mFi(wi)+λi<j∑mA(∣∣wi−wj∣∣2)}(1)
其中
W
=
[
w
1
,
.
.
.
,
w
m
]
W=[w_1,...,w_m]
W=[w1,...,wm],
λ
>
0
\lambda>0
λ>0是一个正则化参数
(1)中的第一项
∑
i
=
1
m
F
i
(
w
i
)
\sum^m_{i=1}F_i(w_i)
∑i=1mFi(wi)表示为对所有个性化模型训练损失的总和。第二项提高了客户之间的协作效率通过一个注意力机制的函数(attention-inducing)
A
(
∣
∣
w
i
−
w
j
∣
∣
2
)
A(||w_i-w_j||^2)
A(∣∣wi−wj∣∣2)。
现在我们来看一下这个注意力机制函数:
定义1:
A
(
∣
∣
w
i
−
w
j
∣
∣
2
)
A(||w_i-w_j||^2)
A(∣∣wi−wj∣∣2)满足以下定义:
- A A A是一个在 [ 0 , ∞ ) [0,\infty) [0,∞)的递增凹函数
- A A A在 [ 0 , ∞ ) [0,\infty) [0,∞)上连续可微
- 对于 A A A的导数 A ′ A' A′, lim t → 0 + A ′ ( t ) \lim_{t\rightarrow0^+}A'(t) limt→0+A′(t)为有限的
函数 A ( ∣ ∣ w i − w j ∣ ∣ 2 ) A(||w_i-w_j||^2) A(∣∣wi−wj∣∣2)以非线性方式测量 w i w_i wi和 w j w_j wj的差异。一个典型的例子是: A ( ∣ ∣ w i − w j ∣ ∣ 2 ) = 1 − e − ∣ ∣ w i − w j ∣ ∣ 2 / σ A(||w_i-w_j||^2)=1-e^{-||w_i-w_j||^2/\sigma} A(∣∣wi−wj∣∣2)=1−e−∣∣wi−wj∣∣2/σ。
四. 联邦消息传递(FedAMP)
一般方法
定义
F
(
W
)
:
=
∑
i
=
1
m
F
i
(
w
i
)
\mathcal{F}(W):=\sum^m_{i=1}F_i(w_i)
F(W):=∑i=1mFi(wi)以及
A
(
W
)
:
=
∑
i
<
j
m
A
(
∣
∣
w
i
−
w
j
∣
∣
2
)
\mathcal{A}(W):=\sum^m_{i<j}A(||w_i-w_j||^2)
A(W):=∑i<jmA(∣∣wi−wj∣∣2)作为(1)的第一项和第二项,因此变为:
min
w
{
G
(
W
)
:
=
F
(
W
)
+
λ
A
(
W
)
}
(2)
\min_w\{\mathcal{G}(W):=\mathcal{F}(W)+\lambda \mathcal{A}(W)\} \tag2
wmin{G(W):=F(W)+λA(W)}(2)
在第k轮的时候,我们首先更新
A
(
W
)
\mathcal{A}(W)
A(W):
U
k
=
W
k
−
1
−
α
k
∇
A
(
W
k
−
1
)
(3)
U^k=W^{k-1}-\alpha_k\nabla\mathcal{A}(W^{k-1}) \tag 3
Uk=Wk−1−αk∇A(Wk−1)(3)
其中
α
\alpha
α为梯度的步长。之后,我们利用更新的
U
k
U^k
Uk去更新W:
W
k
=
arg min
w
F
(
W
)
+
λ
2
α
k
∣
∣
W
−
U
k
∣
∣
2
(4)
W^k = \argmin_w\mathcal{F}(W)+\frac{\lambda}{2\alpha_k}||W-U^k||^2 \tag4
Wk=wargminF(W)+2αkλ∣∣W−Uk∣∣2(4)
以这种方法更新参数。
FedAMP
通过将所有客户的私人培训数据合并为培训数据,可以轻松实现上述通用方法。为了在不侵犯客户端数据隐私的情况下执行个性化联合学习,我们开发了FedAMP,通过在云服务器上为每个客户端维护个性化的云模型,在客户端-服务器框架中实现通用方法的优化步骤,以及在个性化模型和个性化云模型之间传递加权模型聚合消息。
和一般方法一致,FedAMP首先优化
A
(
W
)
\mathcal{A}(W)
A(W),然后将
U
k
U^k
Uk传到服务端。
让
U
k
=
[
u
1
k
,
.
.
.
,
u
m
k
]
U^k=[u_1^k,...,u_m^k]
Uk=[u1k,...,umk],将(3)式改写为:
u
i
k
=
(
1
−
α
k
∑
j
≠
i
m
A
′
(
∣
∣
w
i
k
−
1
−
w
j
k
−
1
∣
∣
2
)
)
∗
w
i
k
−
1
+
α
k
∑
j
≠
i
m
A
′
(
∣
∣
w
i
k
−
1
−
w
j
k
−
1
∣
∣
2
)
)
∗
w
j
k
−
1
=
ξ
i
,
1
w
1
k
−
1
+
.
.
.
+
ξ
i
,
m
w
m
k
−
1
(5)
\begin{aligned} u_i^k&=(1-\alpha_k\sum^m_{j \not = i}A'(||w_i^{k-1}-w_j^{k-1}||^2))*w_i^{k-1} \\ &+\alpha_k\sum^m_{j \not = i}A'(||w_i^{k-1}-w_j^{k-1}||^2))*w_j^{k-1} \\ &=\xi_{i,1}w_1^{k-1}+...+\xi_{i,m}w_m^{k-1} \end{aligned} \tag 5
uik=(1−αkj=i∑mA′(∣∣wik−1−wjk−1∣∣2))∗wik−1+αkj=i∑mA′(∣∣wik−1−wjk−1∣∣2))∗wjk−1=ξi,1w1k−1+...+ξi,mwmk−1(5)
经常选取一个小一点的
α
k
\alpha_k
αk保证所有的线性组合权重非负。因为
ξ
i
,
1
+
.
.
.
+
ξ
i
,
m
=
1
\xi_{i,1}+...+\xi_{i,m}=1
ξi,1+...+ξi,m=1,
u
k
u_k
uk为客户端个性化模型参数集的凸组合。
服务器计算(5)后,将参数返回给客户端,客户端进行如下更新:
w
i
k
=
arg min
w
∈
R
d
F
i
(
w
)
+
λ
2
α
∣
∣
w
−
u
i
k
∣
∣
2
(6)
w_i^k=\argmin_{w\in \R^d}F_i(w)+\frac{\lambda}{2\alpha}||w-u_i^k||^2 \tag 6
wik=w∈RdargminFi(w)+2αλ∣∣w−uik∣∣2(6)
具体算法如下图所示:
FedAMP合作(解释)
FedAMP自适应地促进相似客户机之间的协作,因为专注的消息传递机制迭代地鼓励相似客户机在个性化联合学习过程中更多地相互协作。
经过(5)的计算,可以算出每一个客户端对服务端的比重,也就是说当两个客户端参数很接近时,随着训练他们也会越来越接近,因此会鼓励模型参数相近的客户端多进行协作,形成正反馈。
这次代码只找到了在华为云上的部分,关键代码被集成到了Moxing的板块,无法查到,因此就不展示。