出错代码
from torch.optim import optimizer
class ScaffoldOptimizer(optimizer):
def __init__(self, params, lr):
super(ScaffoldOptimizer, self).__init__(params, lr)
self.lr = lr
self.params = params
def step(self, server_controls, client_controls):
for k, v in self.params:
# w = w - lr * (w.grad + c - ci)
v.data = v.data - self.lr * (v.grad.data + server_controls[k] - client_controls[k])
解决
应该继承Optimizer:
from torch.optim import Optimizer