FedAvg代码详解

FedAvg代码详解

代码位置:https://github.com/shaoxiongji/federated-learning

README

首先查看作者给的README文档:

在这里插入图片描述

这一部分给出了论文的地址,如果有兴趣的话可以读一下论文,讲的就是FedAvg的思想以及对独立同分布数据和非独立同分布数据的一些研究。

下来是requirements,就是代码的依赖库,需要安装

在这里插入图片描述

如果有不会安装pytorch的同学可以看我之前的博客,从安装CUDA和CUDNN开始的。

运行:

在这里插入图片描述

给出了运行的命令,如果使用MLP和CNN模型单独训练就执行main_nn.py文件,使用联邦学习的训练就执行main_fed.py文件。并且给出了命令行参数:–dataset设置训练集;–iid设置数据是否为独立同分布;以及–num_channels设置数据的通道数量,如果是MNIST数据集就是1,CIFAR-10是3;–epochs设置训练的轮数;–gpu设置是否使用GPU;–all_client设置平局所有客户端模型。

作者给出的结果:

在这里插入图片描述

项目目录

在这里插入图片描述

  • data文件夹下存放数据集,有MNIST和CIFAR-10
  • models文件夹下存放和模型相关的文件
  • save存放训练的结果
  • utils存放一些工具

分析代码

所有_init_.py文件都不需要分析

Models目录下代码

Fed.py
import copy
import torch
from torch import nn


def FedAvg(w):
    w_avg = copy.deepcopy(w[0])  # 创建一个深拷贝,用于平均参数的累加
    for k in w_avg.keys():  # 遍历参数字典的键
        for i in range(1, len(w)):  # 遍历参与平均的参数列表
            w_avg[k] += w[i][k]  # 将每个参数按键累加到平均参数中
        w_avg[k] = torch.div(w_avg[k], len(w))  # 将累加的参数值除以参与平均的参数个数,得到平均值
    return w_avg  # 返回平均参数

Fed文件主要做的事情就是对权重求平均

这段代码可能难以理解的就是这个深拷贝,我这里给出一个测试代码,大家可以自行理解,这里就不展开了,有问题可以写在评论区或者私信我。

import copy

a = 1   # =赋值 不可变元素如字符串、数值等
b = a
print('a和b的值:', a, b)
print('a和b的id:', id(a), id(b))
a = 2
print('修改a后a和b的值:', a, b)
print('修改a后a和b的id:', id(a), id(b))
# a和b的值: 1 1
# a和b的id: 2279821764912 2279821764912
# 修改a后a和b的值: 2 1
# 修改a后a和b的id: 2279821764944 2279821764912

c = [1,2,3] # =赋值 可变元素如列表、字典等
d = c
print('c和d的值:', c, d)
print('c和d的id:', id(c), id(d))
c.append(4)
print('修改c后c和d的值:', c, d)
print('修改c后c和d的id:', id(c), id(d))
# c和d的值: [1, 2, 3] [1, 2, 3]
# c和d的id: 2279827005184 2279827005184
# 修改c后c和d的值: [1, 2, 3, 4] [1, 2, 3, 4]
# 修改c后c和d的id: 2279827005184 2279827005184

orignal_list = [1, 2, 3, [4]]   # 定义一个有嵌套层次的列表
copy_list = copy.copy(orignal_list) # 浅拷贝
deepcopy_list = copy.deepcopy(orignal_list) # 深拷贝
print(copy_list, deepcopy_list)
print(id(orignal_list), id(copy_list), id(deepcopy_list))
orignal_list.append(5)
orignal_list[-2].append(6)  # 修改嵌套内的可变对象
print(orignal_list, copy_list, deepcopy_list)
# [1, 2, 3, [4]] [1, 2, 3, [4]]
# 1985362690688 1985365369664 1985365370112
# [1, 2, 3, [4, 6], 5] [1, 2, 3, [4, 6]] [1, 2, 3, [4]]	# copy和deepcopy的差别就是:copy只复制了第一层 而更深的层次还是一种引用;deepcopy是递归复制 所有层次完全复制到新的内存空间
Nets.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)  # 输入层
        self.relu = nn.ReLU()  # ReLU激活函数
        self.dropout = nn.Dropout()  # Dropout层
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)  # 隐藏层

    def forward(self, x):
        x = x.view(-1, x.shape[1] * x.shape[-2] * x.shape[-1])  # 将输入展平
        x = self.layer_input(x)  # 输入层
        x = self.dropout(x)  # Dropout层,用于防止过拟合
        x = self.relu(x)  # ReLU激活函数
        x = self.layer_hidden(x)  # 隐藏层
        return x


class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)  # 第一个卷积层,输入通道数为args.num_channels,输出通道数为10
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)  # 第二个卷积层,输入通道数为10,输出通道数为20
        self.conv2_drop = nn.Dropout2d()  # 二维Dropout层
        self.fc1 = nn.Linear(320, 50)  # 全连接层1,输入大小为320,输出大小为50
        self.fc2 = nn.Linear(50, args.num_classes)  # 全连接层2,输入大小为50,输出大小为args.num_classes

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))  # 第一个卷积层后接ReLU激活函数和最大池化层
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))  # 第二个卷积层后接ReLU激活函数、Dropout层和最大池化层
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])  # 将输入展平
        x = F.relu(self.fc1(x))  # 全连接层1后接ReLU激活函数
        x = F.dropout(x, training=self.training)  # Dropout层,用于防止过拟合
        x = self.fc2(x)  # 全连接层2,输出最终结果
        return x


class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 第一个卷积层,输入通道数为3,输出通道数为6
        self.pool = nn.MaxPool2d(2, 2)  # 最大池化层
        self.conv2 = nn.Conv2d(6, 16, 5)  # 第二个卷积层,输入通道数为6,输出通道数为16
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 全连接层1,输入大小为16*5*5,输出大小为120
        self.fc2 = nn.Linear(120, 84)  # 全连接层2,输入大小为120,输出大小为84
        self.fc3 = nn.Linear(84, args.num_classes)  # 全连接层3,输入大小为84,输出大小为args.num_classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))  # 第一个卷积层后接ReLU激活函数和最大池化层
        x = self.pool(F.relu(self.conv2(x)))  # 第二个卷积层后接ReLU激活函数和最大池化层
        x = x.view(-1, 16 * 5 * 5)  # 将输入展平
        x = F.relu(self.fc1(x))  # 全连接层1后接ReLU激活函数
        x = F.relu(self.fc2(x))  # 全连接层2后接ReLU激活函数
        x = self.fc3(x)  # 全连接层3,输出最终结果
        return x

Nets.py是定义模型的文件代码,我们在这里实现我们需要使用的模型就可以

test.py
def test_img(net_g, datatest, args):
    net_g.eval()
    # 将模型设置为评估模式 eval函数的作用就是不启用dropout和BN 否则测试的时候不会是训练好的权重

    test_loss = 0
    correct = 0
    data_loader = DataLoader(datatest, batch_size=args.bs)
    # 创建测试集的数据加载器

    for idx, (data, target) in enumerate(data_loader):
        if args.gpu != -1:
            data, target = data.cuda(), target.cuda()
        # 将数据和标签移到GPU上(如果可用)

        log_probs = net_g(data)	# 模型会自动调用forward函数前向传播 对于分类任务 返回值就是每一类的概率 也就是权重w

        test_loss += F.cross_entropy(log_probs, target, reduction='sum').item()
        # 计算批次损失的总和

        y_pred = log_probs.data.max(1, keepdim=True)[1] 
        # 在第一个维度中寻找最大值并不改变维度

        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()
        # 计算正确预测的数量

    test_loss /= len(data_loader.dataset)
    # 计算平均损失

    accuracy = 100.00 * correct / len(data_loader.dataset)
    # 计算准确率

    if args.verbose:
        print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, len(data_loader.dataset), accuracy))
    # 打印测试结果(平均损失和准确率)

    return accuracy, test_loss
    # 返回准确率和损失

test.py中只有一个test_img函数,输入模型和数据集,返回准确率和测试的损失

Update.py
class DatasetSplit(Dataset):	
    # Dataset的子类 可以创建一个数据子集对象 只包含原始数据集中特定索引的样本 在分割数据集用于训练和测试时使用
    def __init__(self, dataset, idxs):
        self.dataset = dataset
        self.idxs = list(idxs)

    def __len__(self):
        return len(self.idxs)

    def __getitem__(self, item):
        image, label = self.dataset[self.idxs[item]]
        return image, label

这个类在实例化的时候,接收参数dataset和idxs,数据集和要分割的索引,将数据集按照索引进行分割

class LocalUpdate(object):	# 本地模型更新 根据本地数据集进行训练并返回
    def __init__(self, args, dataset=None, idxs=None):
        self.args = args
        self.loss_func = nn.CrossEntropyLoss()	# 交叉熵损失
        self.selected_clients = []	# 选择的客户端节点
        self.ldr_train = DataLoader(DatasetSplit(dataset, idxs), batch_size=self.args.local_bs, shuffle=True)	
        # 加载数据集的子集 通过idxs分割 本地的batch_size由命令行参数给出

    def train(self, net):	# 本地训练
        net.train()		# 设置为训练模式 会启用dropout函数和BN函数
        # train and update
        optimizer = torch.optim.SGD(net.parameters(), lr=self.args.lr, momentum=self.args.momentum)	# 梯度下降算法使用随机梯度下降

        epoch_loss = []
        for iter in range(self.args.local_ep):	# 根据local_ep确定本地训练轮数
            batch_loss = []
            for batch_idx, (images, labels) in enumerate(self.ldr_train):
                images, labels = images.to(self.args.device), labels.to(self.args.device)
                net.zero_grad()
                log_probs = net(images)
                loss = self.loss_func(log_probs, labels)
                loss.backward()		
                # 反向传播 用于计算损失函数对模型参数的梯度 从而更新模型参数以最小化损失函数
                optimizer.step()	# 根据
                if self.args.verbose and batch_idx % 10 == 0:
                    print('Update Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        iter, batch_idx * len(images), len(self.ldr_train.dataset),
                               100. * batch_idx / len(self.ldr_train), loss.item()))
                batch_loss.append(loss.item())
            epoch_loss.append(sum(batch_loss)/len(batch_loss))
        return net.state_dict(), sum(epoch_loss) / len(epoch_loss)

LocalUpdate类主要定义了联邦学习中本地更新的过程,每个客户端使用自己的本地模型进行训练,并将更新后的模型参数传回到中央服务器进行聚合,代码中的形式就是return,net.state_dict()就是模型参数,返回值的第二项就是所有轮次的平均损失。

Utils目录下代码

options.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import argparse


def args_parser():
    parser = argparse.ArgumentParser()
    # federated arguments
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of imges")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')
    args = parser.parse_args()
    return args

这个代码就没什么好说的,命令行的参数含义

sampling.py
def mnist_iid(dataset, num_users):
    """
    Sample I.I.D. client data from MNIST dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)	# 根据数据集长度及客户端数算出需要分成几份
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]	
    # dict_users是一个空字典 all_idxs是一个从0到len(dataset)-1的连续整数的列表
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))	
        # 为每一个客户端随机选择num_items个索引 并且设置不重复选择 然后将这些索引作为一个集合set存在字典中 使用i作为键 这样每个客户端在在 dict_users 字典中都有一个对应的图像索引集合,包含了随机选择的 num_items 个不重复的图像索引。通过这种方式,确保了每个客户端的数据集合起来能够覆盖整个数据集,并且每个客户端的数据量相等。
        all_idxs = list(set(all_idxs) - dict_users[i])	# 每给一个用户分配数据后 将分配过的数据删去
    return dict_users	# 返回一个用户序号-数据集的字典

从MNIST数据集中抽取独立同分布IID的客户端数据,将数据集划分为多个子数据集给每个客户端使用(但是根据这个代码,我发现其实这个函数并没有考虑数据的同分布问题,只是单纯的数量相同)

import numpy as np

def mnist_noniid(dataset, num_users):
    """
    从MNIST数据集中采样非独立同分布(non-I.I.D.)的客户端数据
    :param dataset: MNIST数据集
    :param num_users: 客户端数量
    :return: 包含客户端数据索引的字典
    """
    num_shards, num_imgs = 200, 300  # 数据集划分的分片数和每个分片的图像数量
    idx_shard = [i for i in range(num_shards)]  # 分片索引列表
    dict_users = {i: np.array([], dtype='int64') for i in range(num_users)}  # 存储客户端数据索引的字典
    idxs = np.arange(num_shards * num_imgs)  # 所有图像的索引
    labels = dataset.train_labels.numpy()  # 所有图像的标签

    # 根据标签对索引进行排序
    idxs_labels = np.vstack((idxs, labels))	# 将两个数据idxs和labels按垂直方向堆叠
    idxs_labels = idxs_labels[:, idxs_labels[1, :].argsort()]	# 按照图像标签进行排序
    idxs = idxs_labels[0, :]	# 取出第一行 即排序后的所有图像索引
	
    # 划分并分配数据
    for i in range(num_users):
        rand_set = set(np.random.choice(idx_shard, 2, replace=False))  # 随机选择两个分片
        idx_shard = list(set(idx_shard) - rand_set)  # 从分片索引列表中移除已选择的分片
        for rand in rand_set:
            dict_users[i] = np.concatenate((dict_users[i], idxs[rand * num_imgs:(rand + 1) * num_imgs]), axis=0)
            # 将所选分片的图像索引添加到相应客户端的数据索引列表中

    return dict_users

作用就是采取非独立同分布的客户端数据,不像iid那样均分数据,因为取数据的时候先将数据按照类别排好序之后再进行分配,将数据排好序后,根据随机的分片索引在随机的位置上取连续的数据,这样取到的数据很可能是同一个标签的。

def cifar_iid(dataset, num_users):
    """
    Sample I.I.D. client data from CIFAR10 dataset
    :param dataset:
    :param num_users:
    :return: dict of image index
    """
    num_items = int(len(dataset)/num_users)
    dict_users, all_idxs = {}, [i for i in range(len(dataset))]
    for i in range(num_users):
        dict_users[i] = set(np.random.choice(all_idxs, num_items, replace=False))
        all_idxs = list(set(all_idxs) - dict_users[i])
    return dict_users

和mnist_iid函数完全一致,我也不知道为什么要放两个一模一样的函数。

核心代码

main_fed.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import copy
import numpy as np
from torchvision import datasets, transforms
import torch

from utils.sampling import mnist_iid, mnist_noniid, cifar_iid
from utils.options import args_parser
from models.Update import LocalUpdate
from models.Nets import MLP, CNNMnist, CNNCifar
from models.Fed import FedAvg
from models.test import test_img


if __name__ == '__main__':
    # parse args
    args = args_parser()	# 解析命令行参数
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    # load dataset and split users
    if args.dataset == 'mnist':	# 加载MNIST数据集
        trans_mnist = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
        dataset_train = datasets.MNIST('../data/mnist/', train=True, download=True, transform=trans_mnist)
        dataset_test = datasets.MNIST('../data/mnist/', train=False, download=True, transform=trans_mnist)
        # sample users	# 划分数据集给客户端
        if args.iid:
            dict_users = mnist_iid(dataset_train, args.num_users)
        else:
            dict_users = mnist_noniid(dataset_train, args.num_users)
    elif args.dataset == 'cifar':	# 加载CIFAR-10数据集
        trans_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])	# 因为CIFAR数据集是3通道
        dataset_train = datasets.CIFAR10('../data/cifar', train=True, download=True, transform=trans_cifar)
        dataset_test = datasets.CIFAR10('../data/cifar', train=False, download=True, transform=trans_cifar)
        if args.iid:
            dict_users = cifar_iid(dataset_train, args.num_users)
        else:
            exit('Error: only consider IID setting in CIFAR10')	
            # 代码没有实现CIFAR的非独立同分布
    else:
        exit('Error: unrecognized dataset')
    img_size = dataset_train[0][0].shape	# 输出数据集的图像格式

    # build model	# 调用模型
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=200, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)	# 输出网络的前向函数的返回值,也就是一个张量
    net_glob.train()	# 设置为训练模式

    # copy weights
    w_glob = net_glob.state_dict()	# 取出模型参数即权重

    # 训练
    loss_train = []
    cv_loss, cv_acc = [], []
    val_loss_pre, counter = 0, 0
    net_best = None
    best_loss = None
    val_acc_list, net_list = [], []

    if args.all_clients: 	# 如果命令行参数指定了全部客户端
        print("Aggregation over all clients")
        w_locals = [w_glob for i in range(args.num_users)]	# 聚合所有的本地模型参数
    for iter in range(args.epochs):
        loss_locals = []
        if not args.all_clients:
            w_locals = []
            
        m = max(int(args.frac * args.num_users), 1)	
        # 计算每轮训练中参与训练的用户的数量
        idxs_users = np.random.choice(range(args.num_users), m, replace=False)
        # 随机选择m个不重复的用户索引
        for idx in idxs_users:
            local = LocalUpdate(args=args, dataset=dataset_train, idxs=dict_users[idx])
            # 实例化LocalUpdate类
            w, loss = local.train(net=copy.deepcopy(net_glob).to(args.device))
            # 调用LocalUpdate类的函数train 使用深拷贝取出刚才的模型 这里就因为每个客户端都需要一个单独的模型来训练 所以必须使用深拷贝
            if args.all_clients:
                w_locals[idx] = copy.deepcopy(w)
            else:
                w_locals.append(copy.deepcopy(w))
            loss_locals.append(copy.deepcopy(loss))	# 计算本地损失
        # update global weights
        w_glob = FedAvg(w_locals)	# 使用FedAvg算法聚合模型参数

        # copy weight to net_glob	
        net_glob.load_state_dict(w_glob)	# 加载全局模型

        # print loss
        loss_avg = sum(loss_locals) / len(loss_locals)	# 输出平均损失
        print('Round {:3d}, Average loss {:.3f}'.format(iter, loss_avg))
        loss_train.append(loss_avg)

    # plot loss curve
    plt.figure()	#使用plt可视化损失变化
    plt.plot(range(len(loss_train)), loss_train)
    plt.ylabel('train_loss')
    plt.savefig('./save/fed_{}_{}_{}_C{}_iid{}.png'.format(args.dataset, args.model, args.epochs, args.frac, args.iid))

    # testing
    net_glob.eval()	# 测试模型 设置为eval评估模式
    acc_train, loss_train = test_img(net_glob, dataset_train, args) # 调用test_img测试函数
    acc_test, loss_test = test_img(net_glob, dataset_test, args)
    print("Training accuracy: {:.2f}".format(acc_train))
    print("Testing accuracy: {:.2f}".format(acc_test))
main_nn.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Python version: 3.6

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar


def test(net_g, data_loader):	
    # 计算测试时的平均损失和准确率 实际上和test.py中的test_img代码基本相同
    # testing
    net_g.eval()
    test_loss = 0
    correct = 0
    l = len(data_loader)
    for idx, (data, target) in enumerate(data_loader):
        data, target = data.to(args.device), target.to(args.device)
        log_probs = net_g(data)
        test_loss += F.cross_entropy(log_probs, target).item()
        y_pred = log_probs.data.max(1, keepdim=True)[1]
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()

    test_loss /= len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset),
        100. * correct / len(data_loader.dataset)))

    return correct, test_loss


if __name__ == '__main__':
    # parse args
    args = args_parser()	# 解析命令行参数
    args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')

    torch.manual_seed(args.seed)

    # load dataset and split users	# 这一部分加载数据和main_fed一致
    if args.dataset == 'mnist':
        dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        img_size = dataset_train[0][0].shape
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
        img_size = dataset_train[0][0].shape
    else:
        exit('Error: unrecognized dataset')

    # build model
    if args.model == 'cnn' and args.dataset == 'cifar':
        net_glob = CNNCifar(args=args).to(args.device)
    elif args.model == 'cnn' and args.dataset == 'mnist':
        net_glob = CNNMnist(args=args).to(args.device)
    elif args.model == 'mlp':
        len_in = 1
        for x in img_size:
            len_in *= x
        net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
    else:
        exit('Error: unrecognized model')
    print(net_glob)

    # training	# 训练过程 直接用整个数据集进行训练
    optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
    train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

    list_loss = []
    net_glob.train()
    for epoch in range(args.epochs):
        batch_loss = []
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(args.device), target.to(args.device)
            optimizer.zero_grad()
            output = net_glob(data)
            loss = F.cross_entropy(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 50 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                           100. * batch_idx / len(train_loader), loss.item()))
            batch_loss.append(loss.item())
        loss_avg = sum(batch_loss)/len(batch_loss)
        print('\nTrain loss:', loss_avg)
        list_loss.append(loss_avg)

    # plot loss
    plt.figure()
    plt.plot(range(len(list_loss)), list_loss)
    plt.xlabel('epochs')
    plt.ylabel('train loss')
    plt.savefig('./save/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))

    # testing	# 调用上面的test函数进行测试
    if args.dataset == 'mnist':
        dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ]))
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    elif args.dataset == 'cifar':
        transform = transforms.Compose(
            [transforms.ToTensor(),
             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
        dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
        test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
    else:
        exit('Error: unrecognized dataset')

    print('test on', len(dataset_test), 'samples')
    test_acc, test_loss = test(net_glob, test_loader)

我们可以看到main_nn就是正常的一个训练过程,main_fed是联邦学习的训练测试过程。代码部分讲解就到这里,我们给出一个自己画的main_fed.py代码流程图。

在这里插入图片描述

  • 23
    点赞
  • 22
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
FedAvg pytorch是一个用于联邦学习的算法。它通过对参与者的本地模型进行加权平均来实现全局模型的更新。这个算法的实现非常简单,它首先对每个参与者的模型参数进行平均,然后将平均参数作为全局模型的更新。具体代码如下所示: def FedAvg(w): w_avg = copy.deepcopy(w += w[i][k w_avg[k = torch.true_divide(w_avg[k], len(w)) return w_avg 其中,w是一个包含参与者模型参数的列表。算法遍历每个参数的键值对,将所有参与者的对应参数加和,并将结果除以参与者的数量,得到平均参数作为全局模型的更新。这样,通过不同参与者的贡献,全局模型可以得到更新并获得更好的性能。<span class="em">1</span><span class="em">2</span><span class="em">3</span> #### 引用[.reference_title] - *1* [联邦学习算法FedAvg实现(PyTorch)](https://blog.csdn.net/Joker_1024/article/details/116377064)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] - *2* [联邦元学习算法Per-FedAvg的PyTorch实现](https://blog.csdn.net/Cyril_KI/article/details/123389721)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] - *3* [PyTorch 实现联邦学习FedAvg详解)](https://blog.csdn.net/qq_36018871/article/details/121361027)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 33.333333333333336%"] [ .reference_list ]

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值