Scaffold 基于fedavg方法的改进,代码复现(联邦学习)

当前有个工作需要实现scaffold算法,该方法通过添加修正项c来解决客户端漂移现象,
在参考github上的相关框架后,复现了该算法。
算法分为三个模块:
optimizer: 重写优化器sdg
clientscaffold:客户端操作
serverscaffold:服务端操作

optimizer部分代码:

import torch
from torch.optim import Optimizer

class SCAFFOLDOptimizer(Optimizer):
    def __init__(self, params, lr, weight_decay):
        defaults = dict(lr=lr, weight_decay=weight_decay)
        super(SCAFFOLDOptimizer, self).__init__(params, defaults)
        pass

    def step(self, server_controls, client_controls, closure=None):
        loss = None
        if closure is not None:
            loss = closure

        # for group, c, ci in zip(self.param_groups, server_controls, client_controls):
        #     p = group['params'][0]
        #     if p.grad is None:
        #         continue
        #     d_p = p.grad.data + c.data - ci.data
        #     p.data = p.data - d_p.data * group['lr']
        for group in self.param_groups:
            for p, c, ci in zip(group['params'], server_controls, client_controls):
                if p.grad is None:
                    continue
                d_p = p.grad.data + c.data - ci.data #这里实现用c来更新本地模型
                p.data = p.data - d_p.data * group['lr']
        return loss

serverscaffold:

from flcore.clients.clientscaffold import clientScaffold
from flcore.servers.serverbase import Server
from utils.data_utils import read_client_data
from threading import Thread
import torch
import random

class Scaffold(Server):
    def __init__(self, device, dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
                 num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal, time_threthold):
        super().__init__(dataset, algorithm, model, batch_size, learning_rate, global_rounds, local_steps, join_clients,
                         num_clients, times, eval_gap, client_drop_rate, train_slow_rate, send_slow_rate, time_select, goal, 
                         time_threthold)
        # select slow clients
        self.set_slow_clients()

        self.global_model=model
        for i, train_slow, send_slow in zip(range(self.num_clients), self.train_slow_clients, self
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值