【联邦学习】联邦平均(FedAvg)_附pytorch代码实现

联邦学习

代码可在https://github.com/kt4ngw/fedavg-pytorch 找到。

1 简介

联邦学习是一种分布式的训练架构,一般由一个中心服务器和多个客户端组成。不同于集中式训练,这些客户端在各自的本地数据集上训练模型,只需要将模型参数上传到服务器进行聚合,而不用上传原始数据,从而保护了用户的隐私。

联邦学习算法的步骤如下所示:
假设有 $N$ 个客户端

1)中心服务器初始化其模型 $w_{s}$,然后随机挑选 $C \times N$ 个客户端($C$ 为每轮参与训练的客户端比例),将服务器的全局模型 $w_{s}$ 传输给挑选的客户端;

2)这些客户端根据全局模型 $w_{s}$ 在其本地训练集上训练,得到自己的模型参数 $w_{k}$;

3)这些被选中的客户端上传其模型至中心服务器进行全局聚合,得到新的全局模型 $w_{t}^{new}$;

4)一直循环1)2)3)直到全局模型达到规定的精度或到达预设的轮数。

2 客户端

客户端的职责是1)训练本地模型;2)上传模型

客户端的代码定义如下:



class BaseClient():
    """Federated-learning client: trains a model locally and reports its weights.

    author: kt4ngw
    mail: kt4ngw@163.com
    links: https://github.com/kt4ngw
    """

    def __init__(self, options, id, local_dataset, model, optimizer, ):
        """Store the client's settings, data, model and optimizer.

        Args:
            options: Mapping of run settings. ``'gpu'`` is read here;
                ``'batch_size'`` and ``'local_epoch'`` are read while training.
            id: Identifier reported in the training statistics.
            local_dataset: Dataset wrapped in a ``DataLoader`` for training;
                its ``len()`` is reported as the client's sample count.
            model: Model trained by this client.
            optimizer: Optimizer stepped once per mini-batch.
        """
        self.options = options
        self.id = id
        self.local_dataset = local_dataset
        self.model = model
        self.gpu = options['gpu']
        self.optimizer = optimizer

    def get_model_parameters(self):
        """Return an independent snapshot of the model's ``state_dict``.

        ``state_dict()`` returns tensors that share storage with the live
        parameters. Each entry is cloned so that later training steps or
        ``load_state_dict`` calls on the same model object cannot rewrite a
        snapshot that has already been handed out.
        """
        state_dict = self.model.state_dict()
        for key in list(state_dict):
            state_dict[key] = state_dict[key].detach().clone()
        return state_dict

    def set_model_parameters(self, model_parameters_dict):
        """Load ``model_parameters_dict`` into the model.

        Only the keys present in the model's own ``state_dict`` are read from
        ``model_parameters_dict``; a missing key raises ``KeyError``.
        """
        state_dict = self.model.state_dict()
        for key in state_dict:
            state_dict[key] = model_parameters_dict[key]
        self.model.load_state_dict(state_dict)

    def local_train(self, ):
        """Run one local training pass and time it.

        Returns:
            A pair ``((num_samples, parameters), stats)`` where ``num_samples``
            is ``len(self.local_dataset)``, ``parameters`` is the snapshot
            returned by ``local_update`` and ``stats`` holds ``'id'``,
            ``'time'`` (seconds, rounded to 2 decimals), ``'loss'`` and
            ``'acc'``.
        """
        begin_time = time.time()
        local_model_paras, train_stats = self.local_update(self.local_dataset, self.options, )
        end_time = time.time()
        stats = {'id': self.id, "time": round(end_time - begin_time, 2)}
        stats.update(train_stats)
        return (len(self.local_dataset), local_model_paras), stats

    def local_update(self, local_dataset, options, ):
        """Train the model for ``options['local_epoch']`` epochs.

        Args:
            local_dataset: Dataset yielding ``(X, y)`` pairs.
            options: Mapping providing ``'batch_size'`` and ``'local_epoch'``.

        Returns:
            ``(parameters, stats)``: a snapshot of the trained weights and a
            dict with ``'id'``, ``'loss'`` and ``'acc'``. Loss and accuracy
            are sample-weighted averages over the last epoch only.
        """
        localTrainDataLoader = DataLoader(local_dataset, batch_size=options['batch_size'], shuffle=True)
        self.model.train()
        train_loss = train_acc = train_total = 0
        for epoch in range(options['local_epoch']):
            # Statistics are restarted every epoch, so only the final epoch
            # is reported.
            train_loss = train_acc = train_total = 0
            for X, y in localTrainDataLoader:
                if self.gpu:
                    X, y = X.cuda(), y.cuda()
                # Clear gradients before the backward pass so nothing left
                # over from earlier use of the model leaks into this step.
                self.optimizer.zero_grad()
                pred = self.model(X)
                loss = criterion(pred, y)
                loss.backward()
                self.optimizer.step()
                _, predicted = torch.max(pred, 1)
                batch_size = y.size(0)
                train_loss += loss.item() * batch_size
                train_acc += predicted.eq(y).sum().item()
                train_total += batch_size
        local_model_paras = self.get_model_parameters()
        # Avoid a ZeroDivisionError when no sample was processed
        # (e.g. local_epoch == 0); both sums are 0 in that case.
        denominator = train_total if train_total else 1
        return_dict = {"id": self.id,
                       "loss": train_loss / denominator,
                       "acc": train_acc / denominator}

        return local_model_paras, return_dict

3 中心服务器

中心服务器的职责主要有:1)管理客户端;2)选择客户端;3)聚合模型;4)测试模型等

中心服务器的代码定义如下:

fedavg的训练手段继承该类,然后写一个def train()即可。

import numpy as np
import torch
import time
from src.fed_client.client import BaseClient
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import copy
from src.utils.metrics import Metrics
import torch.nn.functional as F
criterion = F.cross_entropy


class BaseFederated(object):
    """Base class for the central server of a federated-learning run.

    It builds the clients, broadcasts the global model to them, aggregates
    their updates and evaluates the global model. Concrete algorithms
    subclass it and implement ``train``.

    author: kt4ngw
    mail: kt4ngw@163.com
    links: https://github.com/kt4ngw
    """

    def __init__(self, options, dataset, clients_label, model=None, optimizer=None, name=''):
        """Set up the server state and create one client per label partition.

        Args:
            options: Mapping of run settings; ``'gpu'``, ``'batch_size'``,
                ``'round_num'`` and ``'c_fraction'`` are read here.
            dataset: Object exposing ``train_data``, ``train_label``,
                ``test_data`` and ``test_label``.
            clients_label: Per-client index collections, addressed as
                ``clients_label[i]`` for ``i`` in ``range(len(clients_label))``.
            model: Global model shared with every client.
            optimizer: Optimizer shared with every client.
            name: Prefix of the run name passed to ``Metrics``.
        """
        # NOTE(review): when model/optimizer are omitted, self.model and
        # self.optimizer must already exist (e.g. set by a subclass),
        # otherwise the calls below raise AttributeError.
        if model is not None and optimizer is not None:
            self.model = model
            self.optimizer = optimizer
        self.options = options
        self.dataset = dataset
        self.clients_label = clients_label
        self.gpu = options['gpu']
        self.batch_size = options['batch_size']
        self.num_round = options['round_num']
        self.per_round_c_fraction = options['c_fraction']
        self.clients = self.setup_clients(self.dataset, self.clients_label)
        self.clients_num = len(self.clients)
        self.name = '_'.join([name, f'wn{int(self.per_round_c_fraction * self.clients_num)}',
                              f'tn{len(self.clients)}'])
        self.metrics = Metrics(options, self.clients, self.name)
        self.latest_global_model = self.get_model_parameters()

    @staticmethod
    def move_model_to_gpu(model, options):
        """Move ``model`` to CUDA device 0 when ``options['gpu']`` is truthy.

        Truthiness (rather than ``is True``) matches how the flag is tested
        before batches are moved to CUDA elsewhere in this class.
        """
        if options['gpu']:
            device = 0
            torch.cuda.set_device(device)
            model.cuda()
            print('>>> Use gpu on device {}'.format(device))
        else:
            print('>>> Don not use gpu')

    def get_model_parameters(self):
        """Return an independent snapshot of the model's ``state_dict``.

        ``state_dict()`` returns tensors that share storage with the live
        parameters. Each entry is cloned so the snapshot (for example
        ``latest_global_model``) does not change when the shared model is
        trained afterwards.
        """
        state_dict = self.model.state_dict()
        for key in list(state_dict):
            state_dict[key] = state_dict[key].detach().clone()
        return state_dict

    def set_model_parameters(self, model_parameters_dict):
        """Load ``model_parameters_dict`` into the model.

        Only the keys present in the model's own ``state_dict`` are read from
        ``model_parameters_dict``; a missing key raises ``KeyError``.
        """
        state_dict = self.model.state_dict()
        for key in state_dict:
            state_dict[key] = model_parameters_dict[key]
        self.model.load_state_dict(state_dict)

    def train(self):
        """The whole training procedure

        No returns. All results all be saved.
        """
        raise NotImplementedError

    def setup_clients(self, dataset, clients_label):
        """Create one ``BaseClient`` per entry of ``clients_label``.

        Args:
            dataset: Object exposing ``train_data`` and ``train_label``.
            clients_label: Per-client index collections used to slice the
                training data — presumably index arrays; verify against the
                data-partitioning code.

        Returns:
            List of clients whose ids are ``0 .. len(clients_label) - 1``.
        """
        train_data = dataset.train_data
        train_label = dataset.train_label
        all_client = []
        # Index by position instead of iterating directly so that both
        # sequences and mappings keyed by 0..n-1 keep working.
        for i in range(len(clients_label)):
            sample_index = clients_label[i]
            local_dataset = TensorDataset(torch.tensor(train_data[sample_index]),
                                          torch.tensor(train_label[sample_index]))
            all_client.append(BaseClient(self.options, i, local_dataset, self.model, self.optimizer))

        return all_client

    def local_train(self, round_i, select_clients, ):
        """Train every selected client starting from the latest global model.

        Args:
            round_i: Round number, used only in the progress output.
            select_clients: Iterable of clients to train in this round.

        Returns:
            ``(local_model_paras_set, stats)``: the per-client
            ``(num_samples, parameters)`` pairs and the per-client stat dicts,
            both in the order the clients were trained.
        """
        local_model_paras_set = []
        stats = []
        for i, client in enumerate(select_clients, start=1):
            client.set_model_parameters(self.latest_global_model)
            local_model_paras, stat = client.local_train()
            local_model_paras_set.append(local_model_paras)
            stats.append(stat)
            print("Round: {:>2d} | CID: {: >3d} ({:>2d}/{:>2d})| "
                  "Loss {:>.4f} | Acc {:>5.2f}% | Time: {:>.2f}s ".format(
                   round_i, client.id, i, int(self.per_round_c_fraction * self.clients_num),
                   stat['loss'], stat['acc'] * 100, stat['time'], ))
        return local_model_paras_set, stats

    def aggregate_parameters(self, local_model_paras_set):
        """Return the sample-weighted average of the clients' parameters.

        Args:
            local_model_paras_set: Sequence of ``(num_samples, state_dict)``
                pairs as produced by ``local_train``.

        Returns:
            A state-dict-like mapping with one averaged tensor per model key,
            in the dtype of the corresponding model entry.

        Raises:
            ValueError: If there are no updates or the total sample count is 0.
        """
        train_data_num = sum(num_sample for num_sample, _ in local_model_paras_set)
        if not train_data_num:
            raise ValueError("cannot aggregate: no client updates or zero total samples")

        # A fresh state_dict supplies the keys, dtypes and devices; its
        # entries are replaced below, never modified in place.
        averaged_paras = self.model.state_dict()
        for var in list(averaged_paras):
            reference = averaged_paras[var]
            keeps_dtype = reference.is_floating_point() or reference.is_complex()
            # Integer buffers (e.g. BatchNorm's num_batches_tracked) cannot be
            # divided in place, so they are averaged in float64 and rounded
            # back to their own dtype.
            work_dtype = reference.dtype if keeps_dtype else torch.float64
            total = torch.zeros_like(reference, dtype=work_dtype)
            for num_sample, local_model_paras in local_model_paras_set:
                total += num_sample * local_model_paras[var].to(device=total.device, dtype=work_dtype)
            mean = total / train_data_num
            averaged_paras[var] = mean if keeps_dtype else torch.round(mean).to(reference.dtype)
        return averaged_paras

    def test_latest_model_on_testdata(self, round_i):
        """Evaluate the latest global model, print the result and record it.

        Args:
            round_i: Round number used in the output and passed to
                ``self.metrics.update_test_stats``.
        """
        # Collect stats from total test data
        begin_time = time.time()
        stats_from_test_data = self.global_test(use_test_data=True)
        end_time = time.time()

        print('= Test = round: {} / acc: {:.3%} / '
              'loss: {:.4f} / Time: {:.2f}s'.format(
               round_i, stats_from_test_data['acc'],
               stats_from_test_data['loss'], end_time-begin_time))
        print('=' * 102 + "\n")

        self.metrics.update_test_stats(round_i, stats_from_test_data)

    def global_test(self, use_test_data=True):
        """Evaluate ``latest_global_model`` on the dataset's test split.

        Args:
            use_test_data: Accepted for interface compatibility; not read.

        Returns:
            Dict with ``'acc'``, ``'loss'`` (both averaged per sample) and
            ``'num_samples'``.
        """
        assert self.latest_global_model is not None
        self.set_model_parameters(self.latest_global_model)
        test_data = self.dataset.test_data
        test_label = self.dataset.test_label
        testDataLoader = DataLoader(TensorDataset(torch.tensor(test_data), torch.tensor(test_label)), batch_size=10, shuffle=False)
        test_loss = test_acc = test_total = 0.
        # Evaluate in eval mode so layers such as Dropout/BatchNorm behave
        # deterministically, then restore whatever mode the model was in.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for X, y in testDataLoader:
                    # Thanks to @方攵纟从 for the CPU-path correction: tensors
                    # are moved to CUDA only when the gpu flag is set.
                    if self.gpu:
                        X, y = X.cuda(), y.cuda()
                    pred = self.model(X)
                    loss = criterion(pred, y)
                    _, predicted = torch.max(pred, 1)

                    correct = predicted.eq(y).sum()
                    test_acc += correct.item()
                    test_loss += loss.item() * y.size(0)
                    test_total += y.size(0)
        finally:
            self.model.train(was_training)

        stats = {'acc': test_acc / test_total,
                 'loss': test_loss / test_total,
                 'num_samples': test_total,}
        return stats

4 写在最后

其他代码包括1)数据集划分;2)训练;3)保存结果;4)绘图
都可在https://github.com/kt4ngw/fedavg-pytorch 找到。

如果对您有帮助,欢迎及谢谢您star一下。

如果您对文章有建议或指正,也欢迎您留言评论,作者定当虚心修正。

评论 10
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值