联邦学习公式推导--为何只需要发送模型参数而不是模型梯度?

参考文献https://arxiv.org/pdf/1602.05629v4.pdf

  • 对于一个机器学习应用来说,我们需要找到一个目标函数,使其最小化
    f ( w ) = 1 n ∑ i = 1 n f i ( w ) f(w) = \frac{1}{n}\sum_{i=1}^{n}f_i(w) f(w)=n1i=1nfi(w)

  • 上面等式中, f i ( w ) = l ( x i , y i , w ) f_i(w) = l(x_i, y_i, w) fi(w)=l(xi,yi,w),表示参数为 w w w的模型在样本 ( x i , y i ) (x_i, y_i) (xi,yi)上预测的损失。

  • 假设现在有多个设备并行计算模型在某个数据集上的总体预测损失,总样本数为 n n n,设备 k k k上有 n k n_k nk个样本,那么设备 k k k上的损失为:
    F k ( w ) = 1 n k ∑ i = 1 n k f k ( w ) F_{k}(w) = \frac{1}{n_k}\sum_{i=1}^{n_k}f_k(w) Fk(w)=nk1i=1nkfk(w)

  • 那么模型在整个数据集上的预测损失为:
    f ( w ) = ∑ k = 1 K n k n F k ( w ) = 1 n ∑ k = 1 K ∑ i = 1 n k f i ( w ) f(w) = \sum_{k=1}^K \frac{n_k}{n} F_k(w) = \frac{1}{n}\sum_{k=1}^K\sum_{i=1}^{n_k}f_i(w) f(w)=k=1KnnkFk(w)=n1k=1Ki=1nkfi(w)

  • 相当于使用同一个模型在整个数据集上跑了一遍,得到了总体的平均损失

  • 为了减少client和server间的通信次数,可以让更多的计算在client上完成。

    • 假设第 t t t轮通信中设备k进行一次梯度下降得到的梯度为: g k = ∇ F k ( w t ) gk=\nabla F_k(w_t) gk=Fk(wt) g k g_k gk最终会发送到服务器。 w t w_t wt是第 t t t轮通信中模型的参数。

    • 根据求导法则可知: ∇ f ( w t ) = ∑ k = 1 K n k n ∇ F k ( w t ) \nabla f(w_t) = \sum_{k=1}^K \frac{n_k}{n} \nabla F_k(w_t) f(wt)=k=1KnnkFk(wt),所以服务器拿到所以client的参数之后,更新下一轮模型的参数为:
      w t + 1 = w t − α ∇ f ( w t ) = w t − α ∑ k = 1 K n k n g k w_{t+1} = w_t - \alpha \nabla f(w_t) = w_t - \alpha \sum_{k=1}^{K} \frac{n_k}{n} g_k wt+1=wtαf(wt)=wtαk=1Knnkgk

    • 又因为设备 k k k可以用局部数据更新参数:
      w t + 1 k = w t − α g k α g k = w t − w t + 1 k w^k_{t+1} = w_t - \alpha g_k \\ \alpha g_k = w_t - w^k_{t+1} wt+1k=wtαgkαgk=wtwt+1k

    • 代入上面公式:
      w t + 1 = w t − ∑ k = 1 K n k n ( w t − w t + 1 k ) = w t − ∑ k = 1 K n k n w t + ∑ k = 1 K n k n w t + 1 k w t + 1 = ∑ k = 1 K n k n w t + 1 k w_{t+1} = w_t - \sum_{k=1}^K \frac{n_k}{n} (w_t - w^k_{t+1}) = w_t - \frac{\sum_{k=1}^K n_k}{n}w_t + \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1} \\ w_{t+1} = \sum_{k=1}^{K}\frac{n_k}{n}w^k_{t+1} wt+1=wtk=1Knnk(wtwt+1k)=wtnk=1Knkwt+k=1Knnkwt+1kwt+1=k=1Knnkwt+1k

    • 上面是进行一次梯度下降,如果进行多次梯度下降,设备 k k k的更新参数公式为:
      w t + 1 k = w t − α g k 1 − α g k 2    − . . . −    α g k e p o c h   其中 g k i 表示第 i 次梯度下降 w t + 1 k = w t − α ( g k 1 + g k 2    + . . . + g k e p o c h ) 令    g k = g k 1 + g k 2    + . . . + g k e p o c h ,   则有    w t + 1 k = w t − α g k w_{t+1}^k = w_t - \alpha g_k^1 - \alpha g_k^2 \ \ - ...-\ \ \alpha g_k^{epoch} \ \ 其中 g_k^i表示第i次梯度下降 \\ w_{t+1}^k = w_t - \alpha (g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}) \\ 令 \ \ \ g_k = g_k^1 + g_k^2 \ \ + ...+g_k^{epoch}, \ \ 则有 \ \ \ w_{t+1}^k = w_t-\alpha g_k wt+1k=wtαgk1αgk2  ...  αgkepoch  其中gki表示第i次梯度下降wt+1k=wtα(gk1+gk2  +...+gkepoch)   gk=gk1+gk2  +...+gkepoch,  则有   wt+1k=wtαgk

    • 所以可以在本地进行多次梯度下降并更新本地模型参数,然后将本地模型参数发送给服务器,服务器对这些参数进行加权平均得到全局模型参数,最终发送给各个设备,这样就能减少客户端和服务器间的通信次数。

  • 26
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值