参考文献:https://arxiv.org/pdf/1602.05629v4.pdf
-
对于一个机器学习应用来说,我们需要找到一个目标函数,使其最小化
f ( w ) = 1 n ∑ i = 1 n f i ( w ) f(w) = \frac{1}{n}\sum_{i=1}^{n}f_i(w) f(w)=n1i=1∑nfi(w) -
上面等式中, f i ( w ) = l ( x i , y i , w ) f_i(w) = l(x_i, y_i, w) fi(w)=l(xi,yi,w),表示参数为 w w w的模型在样本 ( x i , y i ) (x_i, y_i) (xi,yi)上预测的损失。
-
假设现在有多个设备并行计算模型在某个数据集上的总体预测损失,总样本数为 n n n,设备 k k k上有 n k n_k nk个样本,那么设备 k k k上的损失为:
F k ( w ) = 1 n k ∑ i = 1 n k f k ( w ) F_{k}(w) = \frac{1}{n_k}\sum_{i=1}^{n_k}f_k(w) Fk(w)=nk1i=1∑nkfk(w) -
那么模型在整个数据集上的预测损失为:
f ( w ) = ∑ k = 1 K n k n F k ( w ) = 1 n ∑ k = 1 K ∑ i = 1 n k f i ( w ) f(w) = \sum_{k=1}^K \frac{n_k}{n} F_k(w) = \frac{1}{n}\sum_{k=1}^K\sum_{i=1}^{n_k}f_i(w) f(w)=k=1∑KnnkFk(w)=n1k=1∑Ki=1∑nkfi(w) -
相当于使用同一个模型在整个数据集上跑了一遍,得到了总体的平均损失
-
为了减少client和server间的通信次数,可以让更多的计算在client上完成。
-
假设第 t t t轮通信中设备k进行一次梯度下降得到的梯度为: g k = ∇ F k ( w t ) gk=\nabla F_k(w_t) gk=∇Fk(wt), g k g_k gk最终会发送到服务器。 w t w_t wt是第 t t t轮通信中模型的参数。
-
根据求导法则可知: ∇ f ( w t ) = ∑ k = 1 K n k n ∇ F k ( w t ) \nabla f(w_t) = \sum_{k=1}^K \frac{n_k}{n} \nabla F_k(w_t) ∇f(wt)=∑k=1Knnk∇Fk(wt),所以服务器拿到所以client的参数之后,更新下一轮模型的参数为:
w t + 1 = w t − α ∇ f ( w t ) = w t − α ∑ k = 1 K n k n g k w_{t+1} = w_t - \alpha \nabla f(w_t) = w_t - \alpha \sum_{k=1}^{K} \frac{n_k}{n} g_k wt+1=wt−α∇f(wt)=wt−αk=1∑Knnkgk -
又因为设备 k k k可以用局部数据更新参数:
w t + 1 k = w t − α g k α g k = w t − w t + 1 k w^k_{t+1} = w_t - \alpha g_k \\ \alpha g_k = w_t - w^k_{t+1} wt+1k=wt−αgkαgk=wt−wt+1k -
代入上面公式:
w t + 1 = w t − ∑ k = 1 K n k n ( w t − w t + 1 k ) = w t − ∑ k = 1 K n k n w t + ∑ k = 1 K n k n w t + 1 k w t + 1 = ∑ k = 1 K n k n w t + 1 k w_{t+1} = w_t - \sum_{k=1}^K \frac{n_k}{n} (w_t - w^k_{t+1}) = w_t - \frac{\sum_{k=1}^K n_k}{n}w_t + \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1} \\ w_{t+1} = \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1} wt+1=wt−k=1∑Knnk(wt−wt+1k)=wt−n∑k=1Knkwt+k=1∑Knnkwt+1kwt+1=k=1∑Knnkwt+1k -
上面是进行一次梯度下降,如果进行多次梯度下降,设备 k k k的更新参数公式为:
w t + 1 k = w t − α g k 1 − α g k 2 − . . . − α g k e p o c h 其中 g k i 表示第 i 次梯度下降 w t + 1 k = w t − α ( g k 1 + g k 2 + . . . + g k e p o c h ) 令 g k = g k 1 + g k 2 + . . . + g k e p o c h , 则有 w t + 1 k = w t − α g k w_{t+1}^k = w_t - \alpha g_k^1 - \alpha g_k^2 \ \ - ...-\ \ \alpha g_k^{epoch} \ \ 其中 g_k^i表示第i次梯度下降 \\ w_{t+1}^k = w_t - \alpha (g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}) \\ 令 \ \ \ g_k = g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}, \ \ 则有 \ \ \ w_{t+1}^k = w_t-\alpha g_k wt+1k=wt−αgk1−αgk2 −...− αgkepoch 其中gki表示第i次梯度下降wt+1k=wt−α(gk1+gk2 +...+gkepoch)令 gk=gk1+gk2 +...+gkepoch, 则有 wt+1k=wt−αgk -
所以可以在本地进行多次梯度下降并更新本地模型参数,然后将本地模型参数发送给服务器,服务器对这些参数进行加权平均得到全局模型参数,最终发送给各个设备,这样就能减少客户端和服务器间的通信次数。
-