Communication-Efficient Federated Learning for Heterogeneous Edge Devices Based on Adaptive
Gradient Quantization
Heting Liu, Fang He and Guohong Cao
arXiv
2022
一、动机和贡献
动机:解决FL通信问题的一种重要的方法是 “梯度量化”,但是现在的量化存在以下问题:1)“低精度”量化可以减少数据传输,却引入大的量化误差导致需要更多轮数去训练模型;“高精度”量化量化误差小,却需要传输较多的数据;2)现存量化方式大多基于固定且预设的量化精度,但是一方面由于最优量化精度随时间的推移而不同,另一方面不同client有着不同的通信资源,因此这种静态决定量化精度是不合理的。
贡献:本文通过动态对不同client分配不同的量化精度,旨在尽量减少FL训练过程中的 wall-clock training time,主要包括如下两方面的设计:
- 不同训练轮数有着不同的量化精度:根据量化过程中 “梯度范数gradient norm” 的不同,在训练刚开始时使用大精度量化,在训练后期使用小精度量化 ;
- 不同通信能力client有着不同量化精度:根据client的通信能力,快client赋予大精度量化,慢client赋予小精度量化。
二、算法
2.1 随机均匀量化(QSGD)
假设
s
∈
N
s\in\mathbb{N}
s∈N表示量化精度,
v
=
[
v
1
,
⋯
,
v
d
]
∈
R
d
,
v
≠
0
\mathbf{v}=[v_1,\cdots,v_d]\in\mathbb{R}^d,\mathbf{v}\ne\mathbf{0}
v=[v1,⋯,vd]∈Rd,v=0表示
d
d
d维梯度向量,那么
v
j
v_j
vj 可以由量化函数
Q
s
(
⋅
)
Q_s(\cdot)
Qs(⋅) 定义为:
Q
s
(
v
j
)
=
∣
∣
v
∣
∣
2
⋅
s
i
g
n
(
v
j
)
⋅
ζ
j
(
v
,
s
)
,
Q_{s}(v_{j})=||\mathbf{v}||_{2}\cdot sign(v_{j})\cdot\zeta_{j}(\mathbf{v},s),
Qs(vj)=∣∣v∣∣2⋅sign(vj)⋅ζj(v,s),其中
ζ
j
(
v
,
s
)
\zeta_{j}(\mathbf{v},s)
ζj(v,s) 表示随机变量,定义为:
ζ
j
(
v
,
s
)
=
{
l
/
s
,
w
i
t
h
p
r
o
b
a
b
i
l
i
t
y
(
1
−
∣
v
j
∣
∣
∣
v
∣
∣
2
s
+
l
)
(
l
+
1
)
/
s
,
o
t
h
e
r
w
i
s
e
.
\zeta_j(\mathbf{v},s)=\left\{\begin{array}{cc}l/s,&with~probability~(1-\frac{|v_j|}{||\mathbf{v}||_2}s+l)\\(l+1)/s,&otherwise.\end{array}\right.
ζj(v,s)={l/s,(l+1)/s,with probability (1−∣∣v∣∣2∣vj∣s+l)otherwise.其中,
0
≤
l
<
s
0\leq l<s
0≤l<s 是一个整数,使得
∣
v
j
∣
∣
∣
v
∣
∣
2
∈
[
l
/
s
,
(
l
+
1
)
/
s
]
\frac{|v_{j}|}{||\mathbf{v}||_{2}}\in[l/s,(l+1)/s]
∣∣v∣∣2∣vj∣∈[l/s,(l+1)/s]。特别的,当
v
=
0
\mathbf{v}=\mathbf{0}
v=0,可以有
Q
s
(
v
)
=
0
Q_s(\mathbf{v})=\mathbf{0}
Qs(v)=0。
QSGD可以解释为:将 [ 0 , ∥ v ∥ 2 ] [0,\|\mathbf{v}\|_2] [0,∥v∥2] 之间 “均匀” 划分为 s − 1 s-1 s−1(包括一个符号位) 个桶,因此桶的端点可以表示为 0 = τ 1 < τ 2 < ⋯ < τ s = ∣ ∣ v ∣ ∣ 2 0=\tau_{1}<\tau_{2}<\cdots<\tau_{s}=||\mathbf{v}||_{2} 0=τ1<τ2<⋯<τs=∣∣v∣∣2。因为 ∣ v j ∣ ∈ [ 0 , ∣ ∣ v ∣ ∣ 2 ] |v_{j}|\in[0,||\mathbf{v}||_{2}] ∣vj∣∈[0,∣∣v∣∣2],因此每个 ∣ v j ∣ |v_j| ∣vj∣ 必定属于某个桶 [ τ i , τ i + 1 ) [\tau_i,\tau_{i+1}) [τi,τi+1)。最后,根据概率( ζ j ( v , s ) \zeta_j(\mathbf{v},s) ζj(v,s))决定 Q s ( v j ) Q_s(v_j) Qs(vj) 取左边界 τ i \tau_i τi 还是有边界 τ i + 1 \tau_{i+1} τi+1。
注:这里 s s s 有两层含义,表达量化后梯度所需要的比特数或者真值,需要注意区分。
2.2 Overview of AdaGQ
上图展示了 AdaGQ 的基本流程,其中黑色加粗字体表示的是这篇文章的创新之处,具体表现为如下两方面:
- adaptive:根据 loss decrease rate 和 gradient norm 在不同训练轮数给出不同的量化精度;
- heterogeneous:根据 通信时间 的差异,给不同client不同量化精度以对齐通信时间。
注:与之前QSGD中 s s s 的两层含义不同,在后续写作中, s s s 表示不带符号位的量化后梯度的真值, b = ⌊ log 2 ( s ) + 1 ⌋ b=\lfloor\log_{2}(s)+1\rfloor b=⌊log2(s)+1⌋ 表示相应的比特数。
2.3 Adaptive Quantization
定义 loss decrease rate
R
k
R_k
Rk 为:
R
k
=
(
L
k
−
1
−
L
k
)
/
T
k
−
1
,
k
,
R_k=(L_{k-1}-L_k)/T_{k-1,k},
Rk=(Lk−1−Lk)/Tk−1,k,其中,
L
k
L_k
Lk 表示
k
k
k 轮时所有客户端的平均损失;
T
k
−
1
,
k
T_{k-1,k}
Tk−1,k 表示
k
−
1
k-1
k−1 轮结束到
k
k
k 轮结束所需的时间(这里应该也是平均时间,因为所有client的执行时间都将被对齐)。
假设
R
k
∗
R_k^*
Rk∗ 表示
k
k
k 轮时由最佳量化精度
s
k
∗
s_k^*
sk∗ 得到的最佳 loss decrease rate,那么定义函数:(
L
L
L 和
T
T
T 都是关于
s
s
s 的函数,因此
R
R
R 也是关于
s
s
s 的函数)
f
(
s
k
)
=
R
k
∗
−
R
k
.
f(s_k)=R_k^*-R_k.
f(sk)=Rk∗−Rk.因此,量化精度
s
s
s 可以以如下方式更新:
s
k
+
1
=
s
k
−
λ
∇
f
(
s
k
)
,
s_{k+1}=s_k-\lambda\nabla f(s_k),
sk+1=sk−λ∇f(sk),其中,
λ
\lambda
λ 表示步长。但是遗憾的是,由于函数
f
(
s
k
)
f(s_k)
f(sk) 关于自变量
s
k
s_k
sk 的具体表达形式不清楚,所以直接求导数
∇
f
(
s
k
)
\nabla f(s_k)
∇f(sk) 是不可行的。因此这篇文章利用和 “导数定义” 相似的思想解决,即:选取一个靠近
s
k
s_k
sk 的量化精度
s
k
′
s_k^\prime
sk′,并得到相应的
R
k
′
R_k^\prime
Rk′,这样就可以得到导数
∇
f
(
s
k
)
\nabla f(s_k)
∇f(sk) 的符号为:
s
i
g
n
(
∇
f
(
s
k
)
)
=
s
i
g
n
(
R
k
′
−
R
k
s
k
−
s
k
′
)
sign(\nabla f(s_k))=sign(\frac{R_k^{\prime}-R_k}{s_k-s_k^{\prime}})
sign(∇f(sk))=sign(sk−sk′Rk′−Rk) 这里如何得到
R
k
′
R_k^\prime
Rk′ 将在 “Implementation of AdaGQ“ 小节中给出。因此,更新规则变为:
{
s
^
k
+
1
=
s
k
−
λ
1
,
i
f
s
i
g
n
(
∇
f
(
s
k
)
)
=
1
s
^
k
+
1
=
s
k
+
λ
2
,
i
f
s
i
g
n
(
∇
f
(
s
k
)
)
=
−
1.
\left\{\begin{matrix}&\hat{s}_{k+1}=s_k-\lambda_1,&if&sign(\nabla f(s_k))=1\\&\hat{s}_{k+1}=s_k+\lambda_2,&if&sign(\nabla f(s_k))=-1.\end{matrix}\right.
{s^k+1=sk−λ1,s^k+1=sk+λ2,ififsign(∇f(sk))=1sign(∇f(sk))=−1.其中,
λ
1
=
s
k
2
,
λ
2
=
2
×
s
k
\lambda_1=\frac{s_k}{2},\lambda_2=2\times s_k
λ1=2sk,λ2=2×sk。
注:梯度其实最重要的就是表示更新的方向(即它的符号),至于其绝对值大小可以由”步长“决定,因此这里只考虑梯度的符号是合理的。
最后,根据 ”梯度范数“ 对
s
^
k
+
1
\hat{s}_{k+1}
s^k+1 进行校准:
s
k
+
1
=
s
^
k
+
1
+
λ
g
(
log
2
∣
∣
g
k
∣
∣
−
log
2
∣
∣
g
k
−
1
∣
∣
)
s_{k+1}=\hat{s}_{k+1}+\lambda_{\mathbf{g}}(\log_{2}||\mathbf{g}_{k}||-\log_{2}||\mathbf{g}_{k-1}||)
sk+1=s^k+1+λg(log2∣∣gk∣∣−log2∣∣gk−1∣∣)其中,
λ
g
\lambda_{\mathbf{g}}
λg 表示相应的系数。
2.4 Heterogeneous Quantization
根据client ”历史运行时间“ 确定相应的量化精度,定义为:
E
(
t
i
,
k
+
1
r
)
=
E
(
t
i
,
k
+
1
c
p
)
+
E
(
t
i
,
k
+
1
c
m
)
≈
E
(
t
i
,
k
+
1
c
p
)
+
b
i
,
k
+
1
×
E
(
P
r
i
.
k
+
1
t
r
a
n
s
)
,
\mathbb{E}(t_{i,k+1}^r)=\mathbb{E}(t_{i,k+1}^{cp})+\mathbb{E}(t_{i,k+1}^{cm})\approx\mathbb{E}(t_{i,k+1}^{cp})+b_{i,k+1}\times\mathbb{E}(\frac{P}{r_{i.k+1}^{trans}}),
E(ti,k+1r)=E(ti,k+1cp)+E(ti,k+1cm)≈E(ti,k+1cp)+bi,k+1×E(ri.k+1transP),其中,
t
i
,
k
+
1
c
p
t_{i,k+1}^{cp}
ti,k+1cp 表示client执行 SGD和量化梯度的时间;
t
i
,
k
+
1
c
m
t_{i,k+1}^{cm}
ti,k+1cm 表示上传量化后梯度到sever的时间;
P
P
P 是一个常数表示梯度总数;
r
i
.
k
+
1
t
r
a
n
s
r_{i.k+1}^{trans}
ri.k+1trans 表示client
i
i
i 在
k
+
1
k+1
k+1 轮时的数据传输率。
因此,对齐通信时间可以描述为
E
(
t
1
,
k
+
1
r
)
=
E
(
t
2
,
k
+
1
r
)
=
⋯
=
E
(
t
n
,
k
+
1
r
)
\mathbb{E}(t_{1,k+1}^{r})=\mathbb{E}(t_{2,k+1}^{r})=\cdots=\mathbb{E}(t_{n,k+1}^{r})
E(t1,k+1r)=E(t2,k+1r)=⋯=E(tn,k+1r)。那么对于client
i
i
i 和
j
j
j,其量化精度的关系可以表示为:
b
j
,
k
+
1
=
1
E
(
P
r
j
,
k
+
1
t
r
a
n
s
)
(
E
(
t
i
,
k
+
1
c
p
)
−
E
(
t
j
,
k
+
1
c
p
)
+
b
i
,
k
+
1
×
E
(
P
r
i
,
k
+
1
t
r
a
n
s
)
)
b_{j,k+1}=\frac{1}{\mathbb{E}(\frac{P}{r_{j,k+1}^{trans}})}(\mathbb{E}(t_{i,k+1}^{cp})-\mathbb{E}(t_{j,k+1}^{cp})+b_{i,k+1}\times\mathbb{E}(\frac{P}{r_{i,k+1}^{trans}}))
bj,k+1=E(rj,k+1transP)1(E(ti,k+1cp)−E(tj,k+1cp)+bi,k+1×E(ri,k+1transP))这里需要定义两个变量:
- E ( t i , k + 1 c p ) = 1 k ∑ k ′ = 1 k t i , k ′ c p \begin{aligned}\mathbb{E}(t_{i,k+1}^{cp})=\frac{1}{k}\sum_{k'=1}^{k}t_{i,k'}^{cp}\end{aligned} E(ti,k+1cp)=k1k′=1∑kti,k′cp,根据历史时间的平均得到;
- E ( P r i , k + 1 t r a n s ) ≈ P r i , k t r a n s = t i , k c m / b i , k \mathbb{E}(\frac{P}{r_{i,k+1}^{t\boldsymbol{r}a\boldsymbol{n}s}})\approx\frac{P}{r_{i,k}^{t\boldsymbol{r}a\boldsymbol{n}s}}=t_{i,k}^{c\boldsymbol{m}}/b_{i,k} E(ri,k+1transP)≈ri,ktransP=ti,kcm/bi,k,认为传出率在小时间范围内的变化是不明显的。
因此,如果给定 client
i
i
i 的量化精度,client
j
j
j 的量化精度可以表示为:
b
j
,
k
+
1
=
b
j
,
k
t
j
,
k
c
m
(
1
k
∑
k
′
=
1
k
t
i
,
k
′
c
p
−
1
k
∑
k
′
=
1
k
−
1
t
j
,
k
′
c
p
+
b
i
,
k
+
1
×
t
i
,
k
c
m
b
i
,
k
)
,
∀
j
∈
{
1
,
⋯
,
n
}
,
j
≠
i
.
\begin{aligned}b_{j,k+1}=\frac{b_{j,k}}{t_{j,k}^{cm}}(\frac1k\sum_{k^{\prime}=1}^{k}t_{i,k^{\prime}}^{cp}-\frac1k\sum_{k^{\prime}=1}^{k-1}t_{j,k^{\prime}}^{cp}+b_{i,k+1}\times\frac{t_{i,k}^{cm}}{b_{i,k}}),\forall j\in\{1,\cdots,n\},j\neq i.\end{aligned}
bj,k+1=tj,kcmbj,k(k1k′=1∑kti,k′cp−k1k′=1∑k−1tj,k′cp+bi,k+1×bi,kti,kcm),∀j∈{1,⋯,n},j=i.
2.5 Implementation of AdaGQ
上图表示 AdaGQ 在
k
+
1
k+1
k+1 轮时的时间线图。其中,
t
k
+
1
d
o
w
n
t_{k+1}^{down}
tk+1down 表示sever发送同时client接收模型所需要的时间;
t
k
+
1
s
e
v
e
r
t_{k+1}^{sever}
tk+1sever sever执行模型聚合的时间。
关于如何得到 R k ′ R_k^{\prime} Rk′,分为如下两个步骤:
- 这篇文章定义 s k = 1 n ∑ i = 1 n s i , k s_{k}=\frac{1}{n}\sum_{i=1}^{n}s_{i,k} sk=n1∑i=1nsi,k,且 s k ′ = ⌊ s k / 2 ⌋ s_{k}^{\prime}=\lfloor s_{k}/2\rfloor sk′=⌊sk/2⌋(即比特数 b k ′ = b k − 1 b_k^{\prime}=b_k-1 bk′=bk−1)。
- 同时sever端定义
k
−
1
k-1
k−1 到
k
k
k 轮之间的执行时间
T
k
−
1
,
k
=
m
a
x
{
t
i
,
k
c
p
+
t
i
,
k
c
m
+
t
i
,
k
d
o
w
n
}
+
t
k
s
e
r
v
e
r
.
T_{k-1,k}=max\{t_{i,k}^{cp}+t_{i,k}^{cm}+t_{i,k}^{down}\}+t_{k}^{server}.
Tk−1,k=max{ti,kcp+ti,kcm+ti,kdown}+tkserver.可以容易知道,
T
k
−
1
,
k
′
T_{k-1,k}^{\prime}
Tk−1,k′ 和
T
k
−
1
,
k
T_{k-1,k}
Tk−1,k 的主要差异是关于 传输时间
t
i
,
k
′
c
m
t_{i,k}^{\prime cm}
ti,k′cm 和
t
i
,
k
′
c
m
t_{i,k}^{\prime cm}
ti,k′cm,而传输时间的差异和相应的比特数成比例关系的(即
b
i
,
k
′
b_{i,k}^\prime
bi,k′ 和
b
i
,
k
b_{i,k}
bi,k),因此可以得到
T
k
−
1
,
k
′
=
m
a
x
{
t
i
,
k
c
p
+
⌊
log
2
(
s
i
,
k
′
)
⌋
+
1
⌊
log
2
(
s
i
,
k
)
⌋
+
1
t
i
,
k
c
m
+
t
i
,
k
d
o
w
n
}
+
t
k
s
e
r
v
e
r
.
T_{k-1,k}^{\prime}=max\{t_{i,k}^{cp}+\frac{\lfloor\log_{2}(s_{i,k}^{\prime})\rfloor+1}{\lfloor\log_{2}(s_{i,k})\rfloor+1}t_{i,k}^{cm}+t_{i,k}^{down}\}+t_{k}^{server}.
Tk−1,k′=max{ti,kcp+⌊log2(si,k)⌋+1⌊log2(si,k′)⌋+1ti,kcm+ti,kdown}+tkserver.
这样就可以得到相应的 R k ′ R_k^{\prime} Rk′。
关于如何根据client通信异质得到相应的量化精度。这篇文章中只是说明了:如果得到 client i i i 的量化精度就可以得出 client j j j 的量化精度。那么第一个client 的量化精度如何得出呢?原文中没有说明,我的理解是 ”可以给速度中等的client赋予平均精度,然后依次计算其他client的量化精度“。
AdaGQ 伪代码如下:
三、讨论
本文主要关注的是FL中,尽量减少总训练时间的问题(包括减少每轮执行时间 和 总执行轮数)。同时为了兼顾 模型准确性,根据量化过程中使用范数的特点,在训练开始时尽量使用大精度,在训练后期使用小精度。
主要特点是:
- 提出对不同训练时期使用不同的量化精度
- 量化了各个client通信能力,即使用时间来衡量
不足之处:
- 没有考虑对不同量化精度的模型进行个性化聚合,只是直接使用了FedAvg中根据数据量的大小聚合
- 只考虑 client 之间通信能力的差异,对于 ”算力、存储等“差异没有考虑
- 本文出现的时间线图感觉并行能力不强,是否具有改善的可能