一、概要
本文算法提出的背景是常规的联邦学习方法对cross-silo问题低效,并提高了安全风险,因为在每一轮迭代中都需要交换梯度更新信息。本文提出的FedBCD算法允许多方在通信之前进行多次本地更新,从而减少通信量。
二、关键算法
问题定义
K
K
K个参与方,
N
N
N个数据样本
D
≜
{
ξ
i
}
i
=
1
N
D\triangleq\{\xi_i\}_{i=1}^N
D≜{ξi}i=1N,其中
ξ
≜
(
x
,
y
)
\xi\triangleq(\mathbf{x},y)
ξ≜(x,y)表示为特征和label。特征向量
x
i
∈
R
1
×
d
\mathbf{x}_i\in \R^{1\times d}
xi∈R1×d分布在
K
K
K个参与方中
{
x
i
,
k
∈
R
1
×
d
k
}
k
=
1
K
\{\mathbf{x}_{i,k}\in \R^{1\times d_k}\}_{k=1}^K
{xi,k∈R1×dk}k=1K,
d
k
d_k
dk表示参与方的特征维度。有一方参与方拥有label,假设为参与方
K
K
K。则联邦数据集可以表示为:
D
k
≜
{
x
i
,
k
}
i
=
1
N
,
k
∈
[
K
−
1
]
;
D
K
≜
{
x
i
,
K
,
y
i
,
K
}
i
=
1
N
D_k\triangleq\{\mathbf{x}_{i,k}\}_{i=1}^N,k\in [K-1];D_K\triangleq \{\mathbf{x}_{i,K},y_{i,K}\}_{i=1}^N
Dk≜{xi,k}i=1N,k∈[K−1];DK≜{xi,K,yi,K}i=1N。联邦训练模型:
min
Θ
L
(
Θ
,
D
)
≜
1
N
∑
i
=
1
N
f
(
θ
1
,
.
.
.
,
θ
K
;
ξ
i
)
+
λ
∑
k
=
1
K
γ
(
θ
k
)
(
1
)
\min_{\Theta}L(\Theta, D)\triangleq \frac{1}{N}\sum_{i=1}^Nf(\theta_1,...,\theta_K;\xi_i)+\lambda\sum_{k=1}^K\gamma(\theta_k) \quad(1)
ΘminL(Θ,D)≜N1i=1∑Nf(θ1,...,θK;ξi)+λk=1∑Kγ(θk)(1)
其中
θ
k
∈
R
d
k
\theta_k \in \R^{d_k}
θk∈Rdk表示第k个参与方的模型参数。
f
(
⋅
)
f(·)
f(⋅)和
γ
(
⋅
)
\gamma(·)
γ(⋅)表示损失函数和正则器。
λ
\lambda
λ表示正则器的超参数。对于广义模型,如线性回归、逻辑回归、支持向量机等的损失函数可表示为
f
(
θ
1
,
.
.
.
,
θ
K
;
ξ
i
)
=
f
(
∑
k
=
1
K
x
i
,
k
θ
k
,
y
i
,
K
)
(
2
)
f(\theta_1,...,\theta_K;\xi_i)=f(\sum_{k=1}^K\mathbf{x}_{i,k}\theta_k,y_{i,K}) \quad(2)
f(θ1,...,θK;ξi)=f(k=1∑Kxi,kθk,yi,K)(2)
联邦学习的目标就是每个参与方在不泄露本地数据和模型参数的前提下训练出最优的模型参数
θ
i
\theta_i
θi。
FedBCD算法
假设小批量
S
⊂
D
S\sub D
S⊂D则随机部分梯度
g
k
(
Θ
,
S
)
≜
∇
k
f
(
Θ
;
S
)
+
λ
∇
γ
(
θ
k
)
(
3
)
g_k(\Theta, S)\triangleq \nabla_kf(\Theta;S)+\lambda\nabla_{\gamma}(\theta_k)\quad(3)
gk(Θ,S)≜∇kf(Θ;S)+λ∇γ(θk)(3)
让
H
i
k
≜
x
i
,
k
θ
k
H_i^k\triangleq \mathbf{x}_{i,k}\theta_k
Hik≜xi,kθk,
H
i
≜
∑
k
=
1
K
H
i
k
H_i\triangleq\sum_{k=1}^KH_i^k
Hi≜∑k=1KHik,因此对于损失函数(2)有
∇
k
f
(
Θ
;
S
)
=
1
S
∑
ξ
i
∈
S
∂
f
(
H
i
,
y
i
,
K
)
∂
H
i
(
x
i
,
k
)
T
(
4
)
\nabla_kf(\Theta;S)=\frac{1}{S}\sum_{\xi_i\in S}\frac{\partial f(H_i,y_{i,K})}{\partial H_i}(\mathbf{x}_{i,k})^T\quad(4)
∇kf(Θ;S)=S1ξi∈S∑∂Hi∂f(Hi,yi,K)(xi,k)T(4)
为了计算本地
∇
k
f
(
Θ
;
S
)
\nabla_kf(\Theta;S)
∇kf(Θ;S)每个参与方
k
∈
[
K
−
1
]
k\in[K-1]
k∈[K−1]需要发送
I
S
k
,
K
≜
{
H
i
k
}
i
∈
S
I_S^{k,K}\triangleq\{H_i^k\}_{i\in S}
ISk,K≜{Hik}i∈S给拥有label的一方
K
K
K,有参与方
K
K
K计算
I
S
K
,
q
≜
{
∂
f
(
H
i
,
y
i
,
K
)
∂
H
i
}
i
∈
S
I_S^{K,q}\triangleq\{\frac{\partial f(H_i,y_{i,K})}{\partial H_i}\}_{i\in S}
ISK,q≜{∂Hi∂f(Hi,yi,K)}i∈S,然后发送给其他参与方
k
∈
[
K
−
1
]
k\in[K-1]
k∈[K−1]。
I
q
,
k
(
⋅
)
I^{q,k}(·)
Iq,k(⋅)表示从参与方
q
q
q到
k
k
k收集到的信息集合。
对于任意损失函数,定义计算
∇
k
f
(
Θ
;
S
)
\nabla_kf(\Theta;S)
∇kf(Θ;S)所需要的信息集合为:
I
S
−
k
≜
{
I
S
q
,
k
}
q
≠
k
(
5
)
I_{S}^{-k}\triangleq \{I^{q,k}_S\}_{q\not=k} \quad(5)
IS−k≜{ISq,k}q=k(5)
公式(3)便可写成如下:
g
k
(
Θ
,
S
)
=
∇
k
f
(
I
S
−
k
,
θ
k
;
S
)
+
λ
∇
γ
(
θ
k
)
(
6
)
g_k(\Theta, S)= \nabla_kf(I_{S}^{-k},\theta_k;S)+\lambda\nabla_{\gamma}(\theta_k)\quad(6)
gk(Θ,S)=∇kf(IS−k,θk;S)+λ∇γ(θk)(6)
≜
g
k
(
I
S
−
k
,
θ
k
;
S
)
\triangleq g_k(I_{S}^{-k},\theta_k;S)
≜gk(IS−k,θk;S)
最后整体的梯度可以表示如下:
g
(
Θ
;
S
)
≜
[
g
1
(
I
S
−
1
,
θ
1
;
S
)
;
.
.
.
;
g
K
(
I
S
−
K
,
θ
K
;
S
)
]
g(\Theta;S)\triangleq [g_1(I_{S}^{-1},\theta_1;S);...;g_K(I_{S}^{-K},\theta_K;S)]
g(Θ;S)≜[g1(IS−1,θ1;S);...;gK(IS−K,θK;S)]
对任意一个参与方的
θ
\theta
θ更新:
θ
k
=
θ
k
−
η
g
k
(
I
S
−
k
,
θ
k
;
S
)
\theta_k = \theta_{k}-\eta g_k(I_{S}^{-k},\theta_k;S)
θk=θk−ηgk(IS−k,θk;S)
下图是一个并行的FedBCD-p算法:
可以看出模型每训练Q轮进行一次通信 E x c h a n g e ( 1 , 2 , . . . , K , S ) Exchange({1,2,...,K,S}) Exchange(1,2,...,K,S)。通信有两种:①是参与方 k ∈ [ K − 1 ] k\in[K-1] k∈[K−1]计算 I S k , K I_S^{k,K} ISk,K
发送给持有label的参与方 K K K,然后有参与方 K K K计算 I S K , k I_{S}^{K,k} ISK,k发送给其他参与方;②是其他参与方之间的通信。
FedBCD算法的通信流如下:
本文还对算法的收敛性和安全性进行了分析,感兴趣请阅读论文。
下图是网上文章:
三、总结
由于算法模型训练每Q轮进行一次通信,因此,在计算损失函数优化目标时并不是最新计算得到的数据,是有精度损失的。故本算法是一个通信效率和计算效率的trade-off。
论文链接