Federated Learning in PyTorch: A Fully Commented Implementation

In this post, we first implement traditional centralized training with PyTorch. Without further ado, let's get straight to the point!

```python
import matplotlib
import matplotlib.pyplot as plt   # matplotlib for plotting

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torch.optim as optim
from torchvision import datasets, transforms  # torchvision: datasets provides data loaders and common dataset interfaces; transforms provides common image operations such as cropping and rotation

from utils.options import args_parser
from models.Nets import MLP, CNNMnist, CNNCifar
```

The first two lines import the plotting package; the torch imports follow, with the roles of datasets and transforms noted in the comments. The last two lines import custom modules, which will be explained when they are used.

Next comes the main function:

```python
# parse args: read the command-line arguments defined in options.py
args = args_parser()
# train on a GPU if one is available and requested, otherwise fall back to the CPU
args.device = torch.device('cuda:{}'.format(args.gpu) if torch.cuda.is_available() and args.gpu != -1 else 'cpu')
torch.manual_seed(args.seed)
```

The first two lines call the custom args_parser function, which parses the command line and collects the parameter settings. The third line decides whether training runs on the CPU or a GPU and stores the result in args.device. The final line sets the random seed; it is optional, but fixing the seed makes every run reproducible.
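As a quick illustration of why the seed matters, here is a minimal sketch of my own (not from the original code): after torch.manual_seed, the same sequence of random tensors is produced on every run.

```python
import torch

torch.manual_seed(1)
a = torch.rand(2)         # same values on every run of the script

torch.manual_seed(1)
b = torch.rand(2)         # resetting the seed replays the same sequence
print(torch.equal(a, b))  # True
```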

```python
if args.dataset == 'mnist':
    # train=True builds the dataset from the training split; train=False would use the test split
    dataset_train = datasets.MNIST('./data/mnist/', train=True, download=True,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize((0.1307,), (0.3081,))  # MNIST mean and std
                                   ]))
    # Each element of the dataset is a 2-tuple: the first entry is the image (a PIL.Image.Image
    # in the raw dataset, a tensor once ToTensor is applied), the second is an integer label
    # for the digit the image represents.
    img_size = dataset_train[0][0].shape  # [1, 28, 28]
elif args.dataset == 'cifar':
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_train = datasets.CIFAR10('./data/cifar', train=True, transform=transform, target_transform=None, download=True)
    img_size = dataset_train[0][0].shape  # [3, 32, 32]
else:
    exit('Error: unrecognized dataset')
```

This block loads the dataset through PyTorch: MNIST or CIFAR-10, each available directly via datasets.MNIST / datasets.CIFAR10. The train argument selects the training or test split, and the transform argument preprocesses each image (here, conversion to a tensor followed by normalization).
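To make the tuple structure concrete, here is a small sketch of my own (not part of the original code) that inspects the first training sample:

```python
img, label = dataset_train[0]              # each item is an (image, label) pair
print(img.shape)                           # torch.Size([1, 28, 28]) for MNIST
print(img.min().item(), img.max().item())  # values after Normalize, no longer in [0, 1]
print(label)                               # an integer class index, e.g. 5 for MNIST's first digit
```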

```python
# build the global model according to the chosen architecture and dataset
if args.model == 'cnn' and args.dataset == 'cifar':
    net_glob = CNNCifar(args=args).to(args.device)
elif args.model == 'cnn' and args.dataset == 'mnist':
    net_glob = CNNMnist(args=args).to(args.device)
elif args.model == 'mlp':
    len_in = 1
    for x in img_size:
        len_in *= x  # len_in is the flattened input dimension
    net_glob = MLP(dim_in=len_in, dim_hidden=64, dim_out=args.num_classes).to(args.device)
else:
    exit('Error: unrecognized model')
print(net_glob)
```

This part builds the model, using the custom models.Nets module to pick an architecture that matches the chosen network type and dataset. For the MLP, the input dimension is the product of the image dimensions: on MNIST, img_size is [1, 28, 28], so len_in = 1 × 28 × 28 = 784.

```python
# SGD optimizer: net_glob.parameters() are the model weights; lr is the learning rate; momentum adds inertia to the updates
optimizer = optim.SGD(net_glob.parameters(), lr=args.lr, momentum=args.momentum)
train_loader = DataLoader(dataset_train, batch_size=64, shuffle=True)

list_loss = []
net_glob.train()  # put the model in training mode (enables dropout etc.) before training
for epoch in range(args.epochs):
    batch_loss = []
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(args.device), target.to(args.device)
        optimizer.zero_grad()                    # reset the gradients
        output = net_glob(data)
        loss = F.cross_entropy(output, target)   # compute the loss
        loss.backward()                          # backpropagate
        optimizer.step()                         # update all parameters
        if batch_idx % 50 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))
        batch_loss.append(loss.item())
    loss_avg = sum(batch_loss)/len(batch_loss)
    print('\nTrain loss:', loss_avg)
    list_loss.append(loss_avg)
```

Next is the training loop. We first choose an optimizer; SGD is the common choice here. With momentum, PyTorch's SGD maintains a velocity v ← momentum·v + g and updates the weights as θ ← θ − lr·v, which smooths the raw gradients g. The loop itself follows the standard PyTorch pattern annotated in the comments: zero the gradients, forward pass, compute the loss, backpropagate, and step the optimizer.
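To make the momentum update concrete, here is a sketch of my own (not from the original post) of a single SGD-with-momentum step on a toy scalar parameter, following the PyTorch convention stated above:

```python
import torch

w = torch.tensor([1.0], requires_grad=True)
v = torch.zeros_like(w)            # velocity buffer
lr, momentum = 0.01, 0.5

loss = (w * 3.0).sum()             # toy loss with gradient dL/dw = 3
loss.backward()

with torch.no_grad():
    v = momentum * v + w.grad      # PyTorch-style momentum accumulation
    w -= lr * v                    # parameter update
print(w)                           # tensor([0.9700], requires_grad=True)
```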

```python
plt.figure()
plt.plot(range(len(list_loss)), list_loss)
plt.xlabel('epochs')
plt.ylabel('train loss')
plt.savefig('./log/nn_{}_{}_{}.png'.format(args.dataset, args.model, args.epochs))
```

This part plots the training loss curve to inspect the result; no detailed walkthrough is needed.

```python
if args.dataset == 'mnist':
    dataset_test = datasets.MNIST('./data/mnist/', train=False, download=True,
                                  transform=transforms.Compose([
                                      transforms.ToTensor(),
                                      transforms.Normalize((0.1307,), (0.3081,))
                                  ]))
    test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
elif args.dataset == 'cifar':
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset_test = datasets.CIFAR10('./data/cifar', train=False, transform=transform, target_transform=None, download=True)
    test_loader = DataLoader(dataset_test, batch_size=1000, shuffle=False)
else:
    exit('Error: unrecognized dataset')

print('test on', len(dataset_test), 'samples')
test_acc, test_loss = test(net_glob, test_loader)
```

This is the evaluation stage. The test split is loaded the same way as the training data, and the heavy lifting is done by the test function, shown below:

```python
def test(net_g, data_loader):
    # testing
    net_g.eval()  # put the model in evaluation mode (disables dropout) before testing
    test_loss = 0
    correct = 0
    for idx, (data, target) in enumerate(data_loader):
        data, target = data.to(args.device), target.to(args.device)
        log_probs = net_g(data)
        test_loss += F.cross_entropy(log_probs, target).item()
        y_pred = log_probs.data.max(1, keepdim=True)[1]  # index of the largest logit = predicted class
        correct += y_pred.eq(target.data.view_as(y_pred)).long().cpu().sum()  # count correct predictions

    test_loss /= len(data_loader.dataset)
    accuracy = 100. * correct / len(data_loader.dataset)
    print('\nTest set: Average loss: {:.4f} \nAccuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, len(data_loader.dataset), accuracy))

    # return the accuracy rather than the raw count, so the caller's test_acc is meaningful
    return accuracy, test_loss
```

The procedure closely mirrors training, minus the backward pass and optimizer step.
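One refinement worth noting (my own addition, not in the original code): wrapping the evaluation loop in torch.no_grad() disables gradient tracking, saving memory and time during testing. A minimal sketch:

```python
import torch
import torch.nn.functional as F

def test_no_grad(net_g, data_loader, device):
    # same evaluation as above, but without building the autograd graph
    net_g.eval()
    correct, test_loss = 0, 0.0
    with torch.no_grad():  # no gradients are recorded inside this block
        for data, target in data_loader:
            data, target = data.to(device), target.to(device)
            log_probs = net_g(data)
            test_loss += F.cross_entropy(log_probs, target).item()
            correct += (log_probs.argmax(dim=1) == target).sum().item()
    return 100. * correct / len(data_loader.dataset), test_loss / len(data_loader.dataset)
```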

The code for the custom options module is:

```python
import argparse

def args_parser():
    parser = argparse.ArgumentParser()  # create the parser (an ArgumentParser object)
    # federated arguments, each registered with add_argument()
    parser.add_argument('--epochs', type=int, default=10, help="rounds of training")
    parser.add_argument('--num_users', type=int, default=100, help="number of users: K")
    parser.add_argument('--frac', type=float, default=0.1, help="the fraction of clients: C")
    parser.add_argument('--local_ep', type=int, default=5, help="the number of local epochs: E")
    parser.add_argument('--local_bs', type=int, default=10, help="local batch size: B")
    parser.add_argument('--bs', type=int, default=128, help="test batch size")
    parser.add_argument('--lr', type=float, default=0.01, help="learning rate")
    parser.add_argument('--momentum', type=float, default=0.5, help="SGD momentum (default: 0.5)")
    parser.add_argument('--split', type=str, default='user', help="train-test split type, user or sample")

    # model arguments
    parser.add_argument('--model', type=str, default='mlp', help='model name')
    parser.add_argument('--kernel_num', type=int, default=9, help='number of each kind of kernel')
    parser.add_argument('--kernel_sizes', type=str, default='3,4,5',
                        help='comma-separated kernel size to use for convolution')
    parser.add_argument('--norm', type=str, default='batch_norm', help="batch_norm, layer_norm, or None")
    parser.add_argument('--num_filters', type=int, default=32, help="number of filters for conv nets")
    parser.add_argument('--max_pool', type=str, default='True',
                        help="Whether use max pooling rather than strided convolutions")

    # other arguments
    parser.add_argument('--dataset', type=str, default='mnist', help="name of dataset")
    parser.add_argument('--iid', action='store_true', help='whether i.i.d or not')
    parser.add_argument('--num_classes', type=int, default=10, help="number of classes")
    parser.add_argument('--num_channels', type=int, default=3, help="number of channels of images")
    parser.add_argument('--gpu', type=int, default=0, help="GPU ID, -1 for CPU")
    parser.add_argument('--stopping_rounds', type=int, default=10, help='rounds of early stopping')
    parser.add_argument('--verbose', action='store_true', help='verbose print')
    parser.add_argument('--seed', type=int, default=1, help='random seed (default: 1)')
    parser.add_argument('--all_clients', action='store_true', help='aggregation over all clients')

    parser.add_argument('--q', type=float, default=1.0, help='parameter of q-fedavg "q"')
    parser.add_argument('--num_clients_per_round', type=int, default=33, help="number of clients per round")
    args = parser.parse_args()  # parse the command line with parse_args()
    return args
```

These are all just parameter definitions.
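As a usage sketch (my own illustration, not from the post): args_parser reads sys.argv by default, but argparse can also be fed an explicit list, which is handy for experimenting with the defaults:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='mnist')
parser.add_argument('--model', type=str, default='mlp')

# parse an explicit list instead of the real command line
args = parser.parse_args(['--dataset', 'cifar', '--model', 'cnn'])
print(args.dataset, args.model)  # cifar cnn
```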

The code for the custom Nets module is:

```python
import torch
from torch import nn
import torch.nn.functional as F


class MLP(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super(MLP, self).__init__()
        self.layer_input = nn.Linear(dim_in, dim_hidden)  # nn.Linear is a fully connected layer -- here, the MLP's input layer
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.layer_hidden = nn.Linear(dim_hidden, dim_out)

    def forward(self, x):
        x = x.view(-1, x.shape[1]*x.shape[-2]*x.shape[-1])  # flatten each image into a vector
        x = self.layer_input(x)
        x = self.dropout(x)
        x = self.relu(x)
        x = self.layer_hidden(x)
        return x


class CNNMnist(nn.Module):
    def __init__(self, args):
        super(CNNMnist, self).__init__()
        self.conv1 = nn.Conv2d(args.num_channels, 10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, args.num_classes)

    def forward(self, x):
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        x = x.view(-1, x.shape[1]*x.shape[2]*x.shape[3])
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = self.fc2(x)
        return x


class CNNCifar(nn.Module):
    def __init__(self, args):
        super(CNNCifar, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, args.num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
```

This defines the three model classes used above: an MLP, a small CNN for MNIST, and a small CNN for CIFAR-10.
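As a quick sanity check (a sketch of my own, assuming a hypothetical args with num_channels=1 and num_classes=10 for MNIST), you can push a dummy batch through the models and confirm the output shapes:

```python
import argparse
import torch

# hypothetical stand-in for the parsed args the models expect
args = argparse.Namespace(num_channels=1, num_classes=10)

net = CNNMnist(args=args)
out = net(torch.randn(2, 1, 28, 28))  # a dummy batch of two 28x28 grayscale images
print(out.shape)                      # torch.Size([2, 10]) -- one logit per class

mlp = MLP(dim_in=28 * 28, dim_hidden=64, dim_out=10)
print(mlp(torch.randn(2, 1, 28, 28)).shape)  # also torch.Size([2, 10])
```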

The next post will cover the FedAvg algorithm.

The full code is at GitHub - shaoxiongji/federated-learning: A PyTorch Implementation of Federated Learning (http://doi.org/10.5281/zenodo.4321561).

Federated learning (FedAvg) is a distributed machine-learning approach: several parties train a model on their own local data and share the updated model parameters with one another to obtain a global model. Below is a simple PyTorch sketch of FedAvg.

1. Import the required libraries:

```python
import copy

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
```

2. Define a local party's dataset and model:

```python
class LocalDataset(data.Dataset):
    def __init__(self, samples):
        self.samples = samples  # a list of (input, target) pairs held by one party

    def __getitem__(self, index):
        return self.samples[index]

    def __len__(self):
        return len(self.samples)


class LocalModel(nn.Module):
    def __init__(self):
        super(LocalModel, self).__init__()
        self.fc = nn.Linear(10, 1)

    def forward(self, x):
        return self.fc(x)
```

3. Define the local training function run by each FedAvg participant:

```python
def train_federated(data_loader, model, optimizer):
    criterion = nn.MSELoss()
    model.train()
    running_loss = 0.0
    for inputs, targets in data_loader:
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)  # train on this party's local data
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return model.state_dict(), running_loss / len(data_loader)
```

4. Initialize each party's data and model and run the FedAvg iterations. (The aggregation loop in the original snippet added each parameter to itself instead of averaging; here it is corrected to average the state dicts across parties and load the result back into every local model.)

```python
def federated_avg(local_data, num_epochs, lr):
    # one model per participating party
    models = [LocalModel() for _ in range(len(local_data))]

    for epoch in range(num_epochs):
        model_states = []
        avg_loss = 0.0
        for i, model in enumerate(models):
            optimizer = optim.SGD(model.parameters(), lr=lr)
            data_loader = data.DataLoader(LocalDataset(local_data[i]), batch_size=32, shuffle=True)
            model_state, loss = train_federated(data_loader, model, optimizer)
            model_states.append(model_state)
            avg_loss += loss
        avg_loss /= len(models)

        # FedAvg aggregation: average the parameters across all parties ...
        global_state = copy.deepcopy(model_states[0])
        for key in global_state.keys():
            for state in model_states[1:]:
                global_state[key] += state[key]
            global_state[key] = torch.div(global_state[key], len(model_states))

        # ... and push the averaged weights back to every local model
        for model in models:
            model.load_state_dict(global_state)

    return models
```

This is a bare-bones PyTorch example of FedAvg. A real deployment also has to address the security of parameter transmission and communication efficiency, among other concerns.
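One refinement worth flagging (my own addition, following the original FedAvg paper by McMahan et al. rather than the snippet above): clients are usually weighted by the size of their local datasets rather than uniformly. A hedged sketch of that weighted aggregation, reusing the model_states list from step 4 plus a parallel list of per-party sample counts:

```python
import copy

def weighted_fed_avg(model_states, num_samples):
    # weight each party's parameters by its share of the total training data
    total = sum(num_samples)
    global_state = copy.deepcopy(model_states[0])
    for key in global_state.keys():
        global_state[key] = sum(
            state[key] * (n / total) for state, n in zip(model_states, num_samples)
        )
    return global_state
```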
