再水一天
大道至简呐
代码如下：
# FedProx: add the proximal regularizer (mu/2) * ||w - w_global||^2 to the
# task loss, pulling the local model back toward the global model.
# NOTE(review): `loss`, `global_model`, `self.model` and `self.algo_params.mu`
# come from the enclosing training loop, which is not visible in this chunk.
loss += self._proximal_term(global_model, self.model, self.algo_params.mu)
def _proximal_term(self, global_model, model, mu):
"""Proximal regularizer of FedProx"""
vec = []
for _, ((name1, param1), (name2, param2)) in enumerate(
zip(model.named_parameters(), global_model.named_parameters())
):
if name1 != name2:
raise RuntimeError
else:
vec.append((param1 - param2).view(-1, 1))
all_vec = torch.cat(vec)
square_term = torch.square(all_vec).sum()
proximal_loss = 0.5 * mu * square_term
return proximal_loss