在这篇blog中我们一起来阅读一下 On the convergence of FedAvg on non-iid data 这篇 ICLR 2020 的paper.
主要目的
本文的主要目的是证明联邦学习算法的收敛性。与之前其他工作中的证明不同,本文的证明更贴近于实际联邦学习的场景。特别的,
- 所有用户的数据non-iid分布;
- 每次只有一部分用户参与FedAvg.
系统模型
考虑一个联邦学习系统 with N N N 用户和一个PS. 每用户有一些local data,训练发生在用户处,每隔一段时间用户上传自己学习的模型来做FedAvg.
将第 k k k 个用户的数据记为 x = { x k , 1 , x k , 2 , x k , 3 , . . . , x k , n k } \bm{x}=\{x_{k,1},x_{k,2},x_{k,3},...,x_{k,n_k}\} x={
xk,1,xk,2,xk,3,...,xk,nk}, 每个人都有一个学习目标,即最小化 loss 函数
F k ( w ) = ∑ j = 1 n k ℓ k ( w , x k , j ) (1) F_k(\bm{w})=\sum_{j=1}^{n_k}\ell_k(\bm{w},x_{k,j}) \tag{1} Fk(w)=j=1∑nkℓk(w,xk,j)(1)
其中 ℓ k ( w , x k , j ) \ell_k(\bm{w},x_{k,j}) ℓk(w,xk,j) 是每个训练数据的 loss. F k ( w ) F_k(\bm{w}) Fk(w) 相当于是每个人所有数据上的loss,如果仅仅做local training, 那最终每个用户会 arrive at
Local minimum : F k ∗ = min w F k \text{Local minimum}:~~~~F_k^*=\min_{\bm{w}} F_k Local minimum: Fk∗=wminFk
而FL考虑的是一种分布式的优化,即我们要minimize的目标函数
Global minimum : F ∗ = min w ∑ k = 1 N p k F k ( w ) \text{Global minimum}:~~~~F^*=\min_{\bm{w}} \sum_{k=1}^{N} p_k F_k(\bm{w}) Global minimum: F∗=wmink=1∑NpkFk(w)
其中 p k p_k pk 是一个distribution用来表示每个用户所占的权重。换句话说,我们最终想找到一个共同的 w \bm{w} w 来最小化每个用户 loss 的一个加权和。
To this end, 本文考虑FedAvg, 并证明其能收敛到 global optimum.
FedAvg 的具体步骤描述如下:首先,我们按单次SGD为一个时间刻度把时间轴分为离散的slot t = 1 , 2 , 3 , . . . , T t=1,2,3,...,T t=1,2,3,...,T, 即总共进行 T T T 次 local SGD, 每次 SGD每个用户从自己的数据集中随机均匀的采样出一个数据来进行训练。特别的,每隔 E E E slots, 所有 active users 把自己的本地参数发送给PS进行 FedAvg,之后PS会把avg后的参数发还给各个用户。以上模型用数学语言可以写为以下两步:
Local training
每个用户在第 t t t 个时刻基于 w t k \bm{w}^k_t wtk 进行 SGD, 得到
v t + 1 k = w t k − η t ∇ ℓ k ( w t k , ξ t k ) (2) \bm{v}^k_{t+1}=\bm{w}^k_t-\eta_{t}\nabla \ell_k(\bm{w}^k_t,\xi^k_t) \tag{2} vt+1k=wtk−ηt∇ℓk(wtk,ξtk)(2)
其中 ξ t k \xi^k_t ξtk 是从本地数据中随机采样出的一个sample。注意,这样单步SGD得到的 v t + 1 k \bm{v}^k_{t+1} vt+1k 只是一个中间变量而不是下一时刻的 w t + 1 k \bm{w}^k_{t+1} wt+1k,因为我们还有可能做 FedAvg。 更具体地说,在 E E E 的非整数倍slot上,
w t + 1 k = v t + 1 k , if t + 1 ∉ J E = { n E : n = 1 , 2 , . . . } . \bm{w}^k_{t+1}=\bm{v}^k_{t+1},~~~~\text{if}~~t+1\notin\mathcal{J}_E=\{nE:n=1,2,...\}. wt+1k=vt+1k, if t+1∈/JE={
nE:n=1,2,...}.
而在 E E E 的整数倍slot上,我们还得额外做 FedAvg.
FedAvg
若下一时刻是 E E E的整数倍周期,即 t + 1 ∈ J E = { n E : n = 1 , 2 , . . . } t+1\in\mathcal{J}_E=\{nE:n=1,2,...\} t+1∈JE={
nE:n=1,2,...},我们进行FedAvg,此时
w t + 1 k = ∑ k = 1 N p k v t + 1 k (3) \bm{w}^k_{t+1}=\sum_{k=1}^N p_k \bm{v}^k_{t+1} \tag{3} wt+1k=k=1∑Npkvt+1k(3)
注意,这里面我们假设每个人都参与更新,稍后我们会release这个条件允许PS按照某种分布采样一部分人进行更新。
小结
如果我们从每个用户的角度看,它的参数变化可以用下图归纳 ( E = 3 E=3 E=3)。
几个假设
本文的推导基于以下假设。
Assumption 1 ( L L L-smoothness). 所有用户的 loss 函数 { F k : k = 1 , 2 , . . . , N } \{F^k:k=1,2,...,N\} { Fk:k=1,2,...,N} 都是 L-smooth.
F k ( x 2 ) − F k ( x 1 ) ≤ ∇ f ( x 1 ) ⊤ ( x 2 − x 1 ) + L 2 ∥ x 2 − x 1 ∥ 2 F^k(\bm{x_2})-F^k(\bm{x_1})\leq \nabla f(\bm{x_1})^\top (\bm{x_2-x_1}) + \frac{L}{2}\|\bm{x_2-x_1}\|^2 Fk(x2)−Fk(x1)≤∇f(x1)⊤(x2−x1)+2L∥x2−x1∥2
Assumption 2 ( μ \mu μ-strongly convex). 所有用户的 loss 函数 { F k : k = 1 , 2 , . . . , N } \{F^k:k=1,2,...,N\} { Fk