当前有个工作需要实现scaffold算法,该方法通过添加修正项c来解决客户端漂移现象,
在参考github上的相关框架后,复现了该算法。
算法分为三个模块:
optimizer: 重写优化器SGD
clientscaffold:客户端操作
serverscaffold:服务端操作
optimizer部分代码:
import torch
from torch.optim import Optimizer
class SCAFFOLDOptimizer(Optimizer):
    """SGD-style optimizer implementing the SCAFFOLD update rule.

    Each parameter is updated with a control-variate-corrected gradient:
        p <- p - lr * (grad + c - c_i)
    where c is the server control variate and c_i the client control
    variate, which counteracts client drift in federated learning.
    """

    def __init__(self, params, lr, weight_decay):
        # NOTE(review): weight_decay is stored in the param group but never
        # applied in step(); kept for interface compatibility with callers.
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SCAFFOLDOptimizer, self).__init__(params, defaults)

    def step(self, server_controls, client_controls, closure=None):
        """Perform one SCAFFOLD optimization step.

        Args:
            server_controls: iterable of tensors (server control variates c),
                aligned one-to-one with the parameters in each group.
            client_controls: iterable of tensors (client control variates c_i),
                aligned the same way.
            closure: optional callable that re-evaluates the model and
                returns the loss.

        Returns:
            The loss returned by ``closure()`` if a closure was given,
            otherwise ``None``.
        """
        loss = None
        if closure is not None:
            # Bug fix: the closure must be *called*; the original code
            # assigned the function object itself to ``loss``.
            loss = closure()
        for group in self.param_groups:
            for p, c, ci in zip(group['params'], server_controls, client_controls):
                if p.grad is None:
                    continue
                # Drift-corrected gradient: grad + c - c_i.
                d_p = p.grad.data + c.data - ci.data
                p.data = p.data - d_p * group['lr']
        return loss
serverscaffold:
from flcore.clients.clientscaffold import clientScaffold
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data
from threading import Thread
import torch
import random
class Scaffold(Server):
def __init__(self, device, dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal, time_threthold):
super().__init__(dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal,
time_threthold)
# select slow clients
self.set_slow_clients()
self.global_model=model
for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients, self