Paper link: Link
-
Abstract: In federated learning (FL), due to multiple local updates and non-i.i.d. datasets, clients easily overfit to local optima and drift away from the global objective, which hurts performance. Most prior work takes the optimization perspective and focuses only on enhancing the consistency between local and global objectives (FedLin, FedDyn, ...) to mitigate client drift. This paper proposes FedSMOO, a novel and general algorithm that couples the optimization objective with the generalization objective to effectively improve FL performance. FedSMOO adopts a dynamic regularizer to pull local optima toward the global objective, while correcting the local perturbations via a global Sharpness-Aware Minimization (SAM) optimizer to search for consistent flat minima. Theoretical analysis shows that FedSMOO achieves a fast $\mathcal{O}(\frac{1}{T})$ convergence rate with a low generalization bound.
-
Main Contributions:
(1) Whereas FedSAM only seeks local flatness, FedSMOO promotes global generality by iteratively updating the local perturbation $s_i$
(2) Similar to FedDyn, a dynamic regularization term is added to ensure global consistency
-
Motivation:
(1) Sharpness-Aware Minimization (SAM) aims to solve a min-max problem:
$$\min_w \left\{ f_s(w) \triangleq \max_{\|s\| \leq r} f(w+s) \right\}$$
SAM solves this problem approximately: via a first-order Taylor expansion at $w$, the optimal $s$ at $w$ has a closed form:
$$s^*(w) \approx \underset{\|s\| \leq r}{\arg\max} \left\{ f(w) + s^{\top} \nabla f(w) \right\} = r \cdot \nabla f(w) / \|\nabla f(w)\|$$
Hence SAM can be viewed as a GD/SGD-based optimizer modified to search for a "flat landscape" of the current loss. In the centralized setting, the SGD-based update in round $t$ proceeds as follows:
– Compute the stochastic gradient $\nabla f(w)$ on batch $\xi_t$
– Move to the perturbed parameter $w + r\frac{\nabla f(w)}{\|\nabla f(w)\|}$
– Compute the SGD gradient at the perturbed point, then update $w$ with it
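To make the procedure concrete, here is a minimal numpy sketch of one SAM-SGD step; the function names and the toy quadratic loss are illustrative, not from the paper:

```python
import numpy as np

def sam_sgd_step(w, grad_fn, lr=0.1, r=0.05):
    """One SAM step; grad_fn(w) returns a (mini-batch) gradient at w."""
    g = grad_fn(w)                           # gradient at the current point
    s = r * g / (np.linalg.norm(g) + 1e-12)  # ascent step to the worst-case neighbor
    return w - lr * grad_fn(w + s)           # descend with the perturbed gradient

# Toy usage on f(w) = ||w||^2 / 2, whose gradient is w itself.
w = np.array([1.0, -2.0])
for _ in range(200):
    w = sam_sgd_step(w, grad_fn=lambda v: v)
print(w)  # settles in a small neighborhood of the minimum at the origin
```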
(2) Issues when extending SAM to the federated setting (FedSAM)
In the federated setting, due to non-i.i.d. data, the aggregate of the local perturbations does not equal the global perturbation (the one induced by the data distribution across all devices). That is, aggregating local flat minima does not necessarily yield a flat minimum of the global loss, so the aggregated model may not sit at a globally flat point, and the gain in generalization is uncertain:
$$\frac{1}{m}\sum_{i=1}^{m} s_i \neq s$$
-
Methodology
1. To address the global flat landscape problem, i.e., to drive the local perturbations toward the global perturbation, the paper first reformulates the optimization problem (this step amounts to solving the inner-level maximization, i.e., promoting global generality):
$$\min_w \left\{ \mathcal{F}(w) = \frac{1}{m} \sum_{i \in [m]} \mathcal{F}_i(w) \right\}, \quad \mathcal{F}_i(w) \triangleq \max_{\|s\| \leq r} f_i(w+s)$$
Introducing constraints, the problem can be rewritten as
$$\min_{w_i = w} \left\{ \mathcal{F} = \frac{1}{m} \sum_{i \in [m]} \mathcal{F}_i(w, w_i, s, s_i) \right\}, \quad \mathcal{F}_i(w, w_i, s, s_i) \triangleq \max_{\|s_i\| \leq r,\; s_i = s} f_i(w_i + s_i)$$
To solve this problem, first apply the same first-order Taylor expansion as in SAM, then define an augmented Lagrangian that penalizes the constraint $s_i = s$ with first- and second-order terms:
$$\mathcal{L}_i^s: f_i(w_i) + s_i^{\top} \nabla f_i(w_i) + \mu_i^{\top}(s - s_i) + \frac{1}{2\alpha}\|s - s_i\|^2, \quad \text{s.t. } \|s_i\| \leq r$$
This defines a modified local loss function whose optimization variables become $\{s_i, \mu_i, s\}$, which are then updated alternately via ADMM:
$$\begin{aligned} \hat{s}_i &= \underset{\|s_i\| \leq r}{\arg\max} \left\{ s_i^{\top}\left(\nabla f_i(w_i) - \mu_i\right) + \frac{1}{2\alpha}\|s_i - s\|^2 \right\} \\ &= \underset{\|s_i\| \leq r}{\arg\max} \left\{ \frac{1}{2\alpha}\|s_i + \bar{s}_i\|^2 \right\}, \end{aligned}$$
$$\bar{s}_i = \alpha\left(\nabla f_i(w_i) - \mu_i\right) - s, \quad \hat{s}_i = r \bar{s}_i / \|\bar{s}_i\|, \quad \mu_i = \mu_i + \frac{1}{\alpha}\left(\hat{s}_i - s\right)$$
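A minimal numpy sketch of this local update, directly transcribing the three formulas above (variable names are illustrative; the small epsilon guards against division by zero):

```python
import numpy as np

def local_perturbation_update(grad_wi, mu_i, s, alpha=0.1, r=0.05):
    """One local ADMM step for the perturbation variables.

    grad_wi: local gradient at w_i; mu_i: dual variable; s: global perturbation.
    """
    s_bar = alpha * (grad_wi - mu_i) - s                    # closed-form maximizer direction
    s_hat_i = r * s_bar / (np.linalg.norm(s_bar) + 1e-12)   # project onto the radius-r ball
    mu_i = mu_i + (s_hat_i - s) / alpha                     # dual ascent on the constraint s_i = s
    return s_hat_i, mu_i
```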
Meanwhile, the corresponding modified global loss function obtained after model aggregation is optimized in the same way:
$$\begin{aligned} \hat{s} &= \underset{\|s\| \leq r}{\arg\max}\, \frac{1}{m} \sum_{i \in [m]} \left\{ s^{\top} \mu_i + \frac{1}{2\alpha}\|s - \hat{s}_i\|^2 \right\} \\ &= \underset{\|s\| \leq r}{\arg\max} \left\{ \frac{1}{2\alpha} \frac{1}{m} \sum_{i \in [m]} \|s + \alpha\mu_i - \hat{s}_i\|^2 \right\}. \end{aligned}$$
$$s = \frac{1}{m} \sum_{i \in [m]} \left(\alpha\mu_i - \hat{s}_i\right), \quad \hat{s} = r s / \|s\|$$
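The corresponding server-side step, again as an illustrative numpy sketch of the two formulas above:

```python
import numpy as np

def global_perturbation_update(mu_list, s_hat_list, alpha=0.1, r=0.05):
    """Server-side update: average the per-client terms α·μ_i − ŝ_i, then project."""
    s = np.mean([alpha * mu - s_hat
                 for mu, s_hat in zip(mu_list, s_hat_list)], axis=0)
    return r * s / (np.linalg.norm(s) + 1e-12)  # project onto the radius-r ball
```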
2. To address the client drift problem, the paper introduces a dynamic regularization term in the FedDyn style (this step amounts to solving the outer-level minimization, i.e., promoting global consistency):
$$\mathcal{L}: \frac{1}{m} \sum_i \left\{ \mathcal{F}_i + \lambda_i^{\top}\left(w^t - w_i\right) + \frac{1}{2\beta}\|w^t - w_i\|^2 \right\}$$
This problem is distributed to the clients and solved round by round, i.e.,
$$w_{i,K}^t = \underset{w_i}{\arg\min} \left\{ \mathcal{F}_i - \lambda_i^{\top} w_i + \frac{1}{2\beta}\|w^t - w_i\|^2 \right\}$$
where $\lambda_i = \lambda_i - \frac{1}{\beta}\left(w_{i,K}^t - w^t\right)$.
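A sketch of the dynamically regularized local objective and the dual update; `F_i` stands in for the modified local loss $\mathcal{F}_i$, and all names are illustrative assumptions:

```python
import numpy as np

def local_objective(w_i, w_t, lam_i, F_i, beta=0.1):
    """FedDyn-style dynamically regularized local objective."""
    return F_i(w_i) - lam_i @ w_i + np.sum((w_t - w_i) ** 2) / (2 * beta)

def dual_update(lam_i, w_iK, w_t, beta=0.1):
    """After the K local steps of round t, update the dual variable λ_i."""
    return lam_i - (w_iK - w_t) / beta
```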
FedDyn link: Link
-
Algorithm Flow
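The paper summarizes the full procedure in an algorithm figure (not reproduced here). Below is a high-level, simplified sketch of one FedSMOO round that combines the pieces above; full participation, a plain SGD inner loop, and all variable names are my own assumptions rather than the paper's exact algorithm:

```python
import numpy as np

def fedsmoo_round(w, s, clients, alpha=0.1, beta=0.1, r=0.05, lr=0.01, K=5):
    """One simplified FedSMOO communication round (illustrative sketch).

    clients: list of dicts holding per-client state 'mu' (dual for s_i = s),
    'lam' (dual for w_i = w), and a gradient oracle 'grad': w -> grad f_i(w).
    """
    w_locals, s_hats = [], []
    for c in clients:
        w_i = w.copy()
        for _ in range(K):
            g = c["grad"](w_i)
            # Local SAM perturbation, corrected toward the global one (ADMM step).
            s_bar = alpha * (g - c["mu"]) - s
            s_hat = r * s_bar / (np.linalg.norm(s_bar) + 1e-12)
            # Gradient at the perturbed point plus the dynamic-regularizer terms.
            g_pert = c["grad"](w_i + s_hat) - c["lam"] + (w_i - w) / beta
            w_i -= lr * g_pert
        c["mu"] = c["mu"] + (s_hat - s) / alpha   # dual update for s_i = s
        c["lam"] = c["lam"] - (w_i - w) / beta    # dual update for w_i = w
        w_locals.append(w_i)
        s_hats.append(s_hat)
    # Server: average the local models and rebuild the global perturbation.
    w_new = np.mean(w_locals, axis=0)
    s_agg = np.mean([alpha * c["mu"] - sh for c, sh in zip(clients, s_hats)], axis=0)
    s_new = r * s_agg / (np.linalg.norm(s_agg) + 1e-12)
    return w_new, s_new
```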
-
Summary & Rethinking
- In fact, $\frac{1}{m} \sum_{i=1}^m \nabla f_i(w_t, \xi_i) \neq \nabla f(w_t, \xi)$ where $\xi_i \sim D_i$, $\xi \sim \sum_{i=1}^m D_i$; is being closer to the global objective necessarily better?
- Nowadays, in the stochastic non-convex setting, it is hard to claim to be "fast" without deriving an $\mathcal{O}(1/T)$ rate.
- Judging from the experiments, FedSMOO's slower convergence in the early rounds matches intuition.
- The necessity of introducing FedDyn is not well demonstrated.