Federated Optimization for Heterogeneous Networks, Fedprox

文章介绍了在分布式机器学习中,FedProx算法的ProximalTerm部分,该部分用于保持本地模型与全局模型之间的距离,防止过大的更新偏离。通过计算模型参数差的平方和并乘以惩罚系数mu,得到正则化损失,从而在优化过程中引入稳定性。
摘要由CSDN通过智能技术生成

再水一天

在这里插入图片描述
大道至简呐

代码

	# Add proximal loss term
	loss += self._proximal_term(global_model, self.model, self.algo_params.mu)

    def _proximal_term(self, global_model, model, mu):
        """Proximal regularizer of FedProx"""

        vec = []
        for _, ((name1, param1), (name2, param2)) in enumerate(
            zip(model.named_parameters(), global_model.named_parameters())
        ):
            if name1 != name2:
                raise RuntimeError
            else:
                vec.append((param1 - param2).view(-1, 1))

        all_vec = torch.cat(vec)
        square_term = torch.square(all_vec).sum()
        proximal_loss = 0.5 * mu * square_term

        return proximal_loss
  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值