On the convergence of FedAvg on non-iid data

在这篇blog中我们一起来阅读一下 On the convergence of FedAvg on non-iid data 这篇 ICLR 2020 的paper.

主要目的

本文的主要目的是证明联邦学习算法的收敛性。与之前其他工作中的证明不同,本文的证明更贴近于实际联邦学习的场景。特别的,

  1. 所有用户的数据non-iid分布;
  2. 每次只有一部分用户参与FedAvg.

系统模型

考虑一个联邦学习系统 with N N N 用户和一个PS. 每用户有一些local data,训练发生在用户处,每隔一段时间用户上传自己学习的模型来做FedAvg.

将第 k k k 个用户的数据记为 x = { x k , 1 , x k , 2 , x k , 3 , . . . , x k , n k } \bm{x}=\{x_{k,1},x_{k,2},x_{k,3},...,x_{k,n_k}\} x={ xk,1,xk,2,xk,3,...,xk,nk}, 每个人都有一个学习目标,即最小化 loss 函数
F k ( w ) = ∑ j = 1 n k ℓ k ( w , x k , j ) (1) F_k(\bm{w})=\sum_{j=1}^{n_k}\ell_k(\bm{w},x_{k,j}) \tag{1} Fk(w)=j=1nkk(w,xk,j)(1)

其中 ℓ k ( w , x k , j ) \ell_k(\bm{w},x_{k,j}) k(w,xk,j) 是每个训练数据的 loss. F k ( w ) F_k(\bm{w}) Fk(w) 相当于是每个人所有数据上的loss,如果仅仅做local training, 那最终每个用户会 arrive at
Local minimum :      F k ∗ = min ⁡ w F k \text{Local minimum}:~~~~F_k^*=\min_{\bm{w}} F_k Local minimum:    Fk=wminFk

而FL考虑的是一种分布式的优化,即我们要minimize的目标函数
Global minimum :      F ∗ = min ⁡ w ∑ k = 1 N p k F k ( w ) \text{Global minimum}:~~~~F^*=\min_{\bm{w}} \sum_{k=1}^{N} p_k F_k(\bm{w}) Global minimum:    F=wmink=1NpkFk(w)

其中 p k p_k pk 是一个distribution用来表示每个用户所占的权重。换句话说,我们最终想找到一个共同的 w \bm{w} w 来最小化每个用户 loss 的一个加权和。

To this end, 本文考虑FedAvg, 并证明其能收敛到 global optimum.

FedAvg 的具体步骤描述如下:首先,我们按单次SGD为一个时间刻度把时间轴分为离散的slot t = 1 , 2 , 3 , . . . , T t=1,2,3,...,T t=1,2,3,...,T, 即总共进行 T T T 次 local SGD, 每次 SGD每个用户从自己的数据集中随机均匀的采样出一个数据来进行训练。特别的,每隔 E E E slots, 所有 active users 把自己的本地参数发送给PS进行 FedAvg,之后PS会把avg后的参数发还给各个用户。以上模型用数学语言可以写为以下两步:

Local training

每个用户在第 t t t 个时刻基于 w t k \bm{w}^k_t wtk 进行 SGD, 得到
v t + 1 k = w t k − η t ∇ ℓ k ( w t k , ξ t k ) (2) \bm{v}^k_{t+1}=\bm{w}^k_t-\eta_{t}\nabla \ell_k(\bm{w}^k_t,\xi^k_t) \tag{2} vt+1k=wtkηtk(wtk,ξtk)(2)

其中 ξ t k \xi^k_t ξtk 是从本地数据中随机采样出的一个sample。注意,这样单步SGD得到的 v t + 1 k \bm{v}^k_{t+1} vt+1k 只是一个中间变量而不是下一时刻的 w t + 1 k \bm{w}^k_{t+1} wt+1k,因为我们还有可能做 FedAvg。 更具体地说,在 E E E 的非整数倍slot上,
w t + 1 k = v t + 1 k ,     if   t + 1 ∉ J E = { n E : n = 1 , 2 , . . . } . \bm{w}^k_{t+1}=\bm{v}^k_{t+1},~~~~\text{if}~~t+1\notin\mathcal{J}_E=\{nE:n=1,2,...\}. wt+1k=vt+1k,    if  t+1/JE={ nE:n=1,2,...}.

而在 E E E 的整数倍slot上,我们还得额外做 FedAvg.

FedAvg

若下一时刻是 E E E的整数倍周期,即 t + 1 ∈ J E = { n E : n = 1 , 2 , . . . } t+1\in\mathcal{J}_E=\{nE:n=1,2,...\} t+1JE={ nE:n=1,2,...},我们进行FedAvg,此时
w t + 1 k = ∑ k = 1 N p k v t + 1 k (3) \bm{w}^k_{t+1}=\sum_{k=1}^N p_k \bm{v}^k_{t+1} \tag{3} wt+1k=k=1Npkvt+1k(3)

注意,这里面我们假设每个人都参与更新,稍后我们会release这个条件允许PS按照某种分布采样一部分人进行更新。

小结

如果我们从每个用户的角度看,它的参数变化可以用下图归纳 ( E = 3 E=3 E=3)。

在这里插入图片描述

几个假设

本文的推导基于以下假设。

Assumption 1 ( L L L-smoothness). 所有用户的 loss 函数 { F k : k = 1 , 2 , . . . , N } \{F^k:k=1,2,...,N\} { Fk:k=1,2,...,N} 都是 L-smooth.
F k ( x 2 ) − F k ( x 1 ) ≤ ∇ f ( x 1 ) ⊤ ( x 2 − x 1 ) + L 2 ∥ x 2 − x 1 ∥ 2 F^k(\bm{x_2})-F^k(\bm{x_1})\leq \nabla f(\bm{x_1})^\top (\bm{x_2-x_1}) + \frac{L}{2}\|\bm{x_2-x_1}\|^2 Fk(x2)Fk(x1)f(x1)(x2x1)+2Lx2x12

Assumption 2 ( μ \mu μ-strongly convex). 所有用户的 loss 函数 { F k : k = 1 , 2 , . . . , N } \{F^k:k=1,2,...,N\} { Fk

评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值