The FedAvg Paper

Paper: Communication-Efficient Learning of Deep Networks from Decentralized Data
Original code
Reproduction

These notes are based on reading the paper together with a walkthrough post.

Federated Averaging (FedAvg) is the prototypical federated learning algorithm: it combines local stochastic gradient descent (SGD) on each client with a server that averages the resulting models.
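From the paper, the global objective is a data-size-weighted sum of the per-client objectives, and each round the server replaces the global weights with the weighted average of the clients' locally updated weights (notation follows the paper: client k holds index set \mathcal{P}_k with n_k = |\mathcal{P}_k| samples and n = \sum_k n_k):

f(w) = \sum_{k=1}^{K} \frac{n_k}{n} F_k(w), \qquad F_k(w) = \frac{1}{n_k} \sum_{i \in \mathcal{P}_k} f_i(w)

w_{t+1} \leftarrow \sum_{k=1}^{K} \frac{n_k}{n}\, w_{t+1}^{k}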

The Federated Optimization Problem

  • Non-IID data

  • Unbalanced amounts of data across clients

  • Massive scale: the number of clients is large

  • Limited communication

The Federated Averaging Algorithm

Communication between the clients and the server is expensive, so the paper proposes two ways to lower the communication cost:

  • Increase parallelism (involve more clients per round)

  • Increase the amount of computation each client performs per round

The paper first defines the FedSGD baseline, in which each selected client performs a single gradient computation per round; executing multiple local SGD steps before averaging yields FedAvg.

  • Only a fraction of the clients is selected to train in each round rather than all of them; beyond a modest participation fraction, adding more clients per round gives diminishing returns in convergence speed and accuracy.

  • The algorithm pushes computation onto the local clients, while the server only aggregates and averages. Multiple local model updates can be run before each averaging step, but too many local epochs cause clients to overfit their local data. The paper's control knobs are summarized after this list.
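In the paper's notation, per-round local computation is controlled by three knobs: C, the fraction of clients sampled per round; E, the number of local epochs; and B, the local batch size. FedSGD is the special case E = 1, B = \infty (one full-batch gradient step per round), and the expected number of local updates client k performs per round is

u_k = E \cdot \frac{n_k}{B}

so FedAvg trades extra local computation for fewer communication rounds by raising E and/or lowering B.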

Code Reproduction

What IID and Non-IID mean:

  • IID (Independent and Identically Distributed): samples do not influence one another and are all drawn from the same distribution.

    IID training samples are therefore representative of the overall population.

  • Non-IID (Non-Independent and Identically Distributed): data in real-world scenarios rarely satisfies the IID assumption. The sketch after this list shows one common way to simulate both settings.
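The sampling functions imported by main_fed.py (utils/sampling.py) are not shown in the post. Here is a minimal sketch of how such splits are commonly built, modeled on the reference repo (the shard counts 200/300 and the two-shards-per-user choice are assumptions):

import numpy as np

def mnist_iid(dataset, num_users):
    # IID: give each user a uniformly random, disjoint subset
    num_items = len(dataset) // num_users
    dict_users, all_idxs = {}, list(range(len(dataset)))
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

def mnist_noniid(dataset, num_users, num_shards=200, num_imgs=300):
    # Non-IID: sort the images by label, cut them into shards, and hand
    # each user two shards, so most users see only one or two digit classes
    idx_shard = list(range(num_shards))
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}
    idxs = np.arange(num_shards * num_imgs)
    labels = np.array(dataset.targets)[:len(idxs)]
    idxs = idxs[labels.argsort()]  # sample indices ordered by label
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))
        idx_shard = list(set(idx_shard) - rand_set)
        for rand in rand_set:
            dict_users[i] = np.concatenate(
                (dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]))
    return dict_users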

The post's overview of the code files:

[figure: project file layout]

My local setup:

[figure: local development environment]

main_fed.py

"""
FileName: 
Author: 
Version: 
Date: 2024/6/10 17:27
Description: 
"""
# imports
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch


from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img

# main function

if __name__ == '__main__':
    # parse args
    args = args_parser()
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':
        # MNIST: convert images to tensors and normalize (mean 0.1307, std 0.3081)
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        # download the training and test sets via torchvision.datasets
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)

        # sample users
        if args.iid: # iid
            dict_users = mnist_iid(dataset_train, args.num_users)
        else: # non-iid
            dict_users = mnist_noniid(dataset_train, args.num_users)

    elif args.dataset == 'cifar':
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
        # download the CIFAR-10 training and test sets
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)

        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape

    # build model
    # net_glob is the global model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob) # print the network architecture
    net_glob.train()

    # copy weights
    w_glob = net_glob.state_dict()
    # training: the core FedAvg loop
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0 # validation-loss bookkeeping (unused in this excerpt)
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    if args.all_clients:
        print("Aggregation over all clients")
        # seed every client's slot with the initial global weights
        w_locals = [w_glob for i in range(args.num_users)]
    for iter in range(args.epochs): # global communication rounds
        loss_locals = [] # local training losses collected this round
        if not args.all_clients:
            w_locals = []
        m = max(int(args.frac * args.num_users), 1) # number of clients sampled this round (fraction args.frac)
        # sample clients
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        for idx in idxs_users:
            # local training: look up this user's data partition by id and train on it
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w)) # collect the local weights
            loss_locals.append(copy.deepcopy(loss)) # record this client's average local loss
        # update global weights
        w_glob = FedAvg(w_locals)
        # copy weight to net_glob
        net_glob.load_state_dict(w_glob) # load the averaged weights for the next round

        # print and record the average loss this round
        loss_avg = sum(loss_locals) / len(loss_locals)
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg) # without this append, the loss curve below stays empty
    # plot loss curve
    plt.figure()
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./fed_loss.png') # Agg backend: the figure must be saved, not shown (filename is arbitrary)
    # testing
    net_glob.eval() # eval() disables dropout and freezes batch-norm statistics
    acc_train, loss_train = test_img(net_glob, dataset_train, args)
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))


Fed.py

Key idea: Fed.py implements the weight-averaging aggregation (the snippet also needs import copy and import torch to run):

import copy
import torch


def FedAvg(w):
    '''
    :param w: a list of client model weights; each element is a state_dict,
        i.e. an ordered dict mapping parameter names to tensors
    :return: the element-wise average of the weights
    '''
    w_avg = copy.deepcopy(w[0]) # deep-copy the first client's weights as the accumulator
    for k in w_avg.keys(): # iterate over the parameter tensors by key
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] # accumulate the remaining clients' tensors
        w_avg[k] = torch.div(w_avg[k], len(w)) # divide by the number of clients
    return w_avg
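Note that this FedAvg is an unweighted mean, which matches the paper's update only when every client holds the same number of samples (true for the IID split above). Here is a minimal sketch of the paper's n_k/n-weighted variant (the name FedAvgWeighted and the num_samples argument are mine, not part of the repo):

import copy
import torch

def FedAvgWeighted(w, num_samples):
    # w: list of client state_dicts; num_samples[i]: sample count on client i
    total = float(sum(num_samples))
    w_avg = copy.deepcopy(w[0])
    for k in w_avg.keys():
        # start from the first client's weighted tensor, then accumulate the rest
        w_avg[k] = w[0][k] * (num_samples[0] / total)
        for i in range(1, len(w)):
            w_avg[k] += w[i][k] * (num_samples[i] / total)
    return w_avg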

Update.py

"""
FileName: 
Author: 
Version: 
Date: 2024/6/10 21:11
Description: 
"""
import torch
from torch import nn, autograd
from torch.utils.data import DataLoader, Dataset
import numpy as np
import random
from sklearn import metrics


class DatasetSplit(Dataset):
    # wraps a dataset and exposes only the samples whose indices are in idxs
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label


class LocalUpdate(object):
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()
        self.selected_clients = []
        # a DataLoader over this client's data partition only
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)

    def train(self, net):
        net.train()
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)

        epoch_loss = []
        for iter in range(self.args.local_ep): # E local epochs
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()
                optimizer.step()
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)
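For intuition, DatasetSplit on its own just restricts a dataset to a subset of indices. A quick sketch (assuming dataset_train from main_fed.py; the indices are arbitrary):

from torch.utils.data import DataLoader

subset = DatasetSplit(dataset_train, idxs=[0, 5, 42]) # only three samples visible
print(len(subset)) # 3
loader = DataLoader(subset, batch_size=2, shuffle=True)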


References

[FedAvg Paper Notes] & [Code Reproduction] (the walkthrough post these notes follow)
