Per-FedAVG源码分析总领

data_dataset.py

这段代码定义了两个自定义的PyTorch数据集类:MNISTDatasetCIFARDataset。这两个类都继承自PyTorch提供的Dataset类,这使得它们能够与PyTorch的数据加载器一起使用。
MNISTDataset是为MNIST数据集设计的,该数据集包含28x28像素的灰度手写数字图像。CIFARDataset是为CIFAR数据集设计的,该数据集包含32x32像素的彩色图像,分为10个不同的类别(飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车)。
这两个数据集的构造函数和方法都相同,唯一的区别是数据集的名称(这并不影响其功能)。构造函数接受数据子集和目标,或者单独的数据和目标张量。它还接受可选的transform和target_transform函数,这些函数可以在迭代数据集时用于对数据和目标应用变换。
__getitem__方法为给定的索引返回一个数据样本及其相应的目标。如果提供了变换函数,它将应用于数据样本。如果提供了target_transform函数,它将应用于目标。
__len__方法返回数据集中的样本数量。

data_utils.py

这段代码提供了两个函数,get_dataloaderget_client_id_indices,用于加载预处理后的联邦学习数据集的数据加载器,以及获取客户端ID的索引。

  1. get_dataloader 函数接受数据集名称(“mnist” 或 “cifar”)、客户端ID、批处理大小和验证集比例作为参数。它首先检查是否存在预处理的数据集pickle文件,如果没有,则抛出运行时错误。然后,它加载指定客户端的数据集,并将其随机划分为训练集和验证集。最后,它为训练集和验证集创建数据加载器,并返回它们。
  2. get_client_id_indices 函数接受数据集名称作为参数。它加载保存的客户端分离信息(“seperation.pkl”),并返回训练客户端、测试客户端和总客户端数的索引列表。
    这些函数允许在联邦学习设置中为每个客户端加载和分割数据集,以便进行模型训练和验证。使用这些函数,可以轻松地获取每个客户端的数据加载器,以及客户端的索引信息。

data_perprocess.py

这段代码是一个Python脚本,用于预处理MNIST或CIFAR10数据集,以适应联邦学习设置。它随机地为每个客户端分配类别,将数据集分为训练集和测试集,然后将这些数据集保存为磁盘上的pickle文件。它还将客户端数据分布的统计信息保存为JSON文件。
以下是代码的详细说明:

  1. 脚本导入了必要的库,包括PyTorch、torchvision和path.py,用于文件系统操作。
  2. DATASETMEANSTD字典定义了数据集、平均值和标准差的信息,用于数据归一化。
  3. preprocess函数是主要的功能,负责处理数据:
    • 它设置数据集目录和pickle目录。
    • 它为随机数生成器设置种子,以确保可重复性。
    • 它根据客户端总数和训练客户端的比例计算训练和测试客户端的数量。
    • 它定义了用于归一化的数据转换。
    • 它加载原始数据集,并使用randomly_alloc_classes函数为客户端分配类别。
    • 它将各个客户端的数据集保存为pickle文件,并将客户端分为训练和测试的分离信息保存为"seperation.pkl"。
    • 它将数据集分布的统计信息保存为"all_stats.json"。
  4. randomly_alloc_classes函数接受原始数据集、目标数据集类、客户端数量、每个客户端的类别数量以及可选的转换。它使用noniid_slicing函数以非独立同分布的方式切分数据集,为每个客户端创建一个包含指定类别数量的数据集列表。它还计算并返回客户端数据分布的统计信息。
  5. __main__块解析命令行参数,包括数据集选择、客户端总数、训练客户端的比例、每个客户端的类别数量和随机种子,然后调用preprocess函数。
    要运行此脚本,通常会在命令行中执行它,并提供必要的参数。例如:
python preprocess.py --dataset mnist --client_num_in_total 100 --fraction 0.8 --classes 2 --seed 123

这将预处理MNIST数据集,用于有100个客户端的联邦学习设置,其中80%是训练客户端,每个客户端有2个类别的数据,并使用随机种子123确保结果的可重复性。

model.py

这段代码定义了两个模型类,MLP_MNISTMLP_CIFAR10,分别用于MNIST和CIFAR10数据集。这两个模型都是多层感知机(MLP),包含三个全连接层,每个层后面都有一个ELU激活函数。
此外,代码还定义了一个MODEL_DICT字典,用于根据数据集名称获取相应的模型类。
最后,get_model函数接受数据集名称和设备(CPU或GPU)作为参数,并返回相应的模型实例。
以下是代码的详细说明:

  1. elu类是一个自定义的ELU激活函数,它使用torch.where来判断输入是否大于0,并返回相应的值。对于小于0的输入,它使用0.2乘以指数函数减1的结果。
  2. linear类是一个自定义的全连接层,它包含权重矩阵w和偏置b。权重矩阵是通过随机初始化得到的,其标准差是输入特征数目的倒数。
  3. MLP_MNISTMLP_CIFAR10类是MNIST和CIFAR10数据集的MLP模型。它们包含三个全连接层,每个层后面都跟着一个ELU激活函数。在MNIST模型中,输入特征是28x28像素的图像,而在CIFAR10模型中,输入特征是32x32x3的图像。
  4. MODEL_DICT字典包含两个模型类,分别对应MNIST和CIFAR10数据集。
  5. get_model函数根据传入的数据集名称和设备,从MODEL_DICT中获取相应的模型类,并实例化模型。如果设备是GPU,它将模型移动到GPU上;如果设备是CPU,它将模型移动到CPU上。

utils.py

这段代码提供了一个命令行参数解析函数 get_args,以及两个用于模型评估和随机种子固定的辅助函数 evalfix_random_seed

  1. get_args 函数使用 ArgumentParser 从命令行中解析参数。这些参数包括学习率、正则化参数、全局和局部训练轮次、个性化阶段的轮次、批处理大小、验证集比例、数据集名称、每轮训练的客户端数量、随机种子、是否使用GPU、是否在训练过程中进行评估,以及是否记录日志。
  2. eval 函数是一个用于评估模型的函数,它接受模型、数据加载器、损失函数和设备作为参数。该函数在评估模式下运行模型,并计算在数据加载器上的总损失和准确率。它返回总损失和准确率。
  3. fix_random_seed 函数用于固定PyTorch和其他库的随机种子,以确保实验的可重复性。它接受一个种子值作为参数,并设置相应的随机种子。此外,它还清空了GPU缓存,并设置了CUDA的确定性选项。
    这些函数为联邦学习实验提供了一个基础框架,用于处理命令行参数、评估模型和确保实验的可重复性。

perfedavg.py

这段代码定义了一个名为 PerFedAvgClient 的类,用于在联邦学习中训练一个客户端的模型。该类支持两种不同的训练模式:Per-FedAvg(FO)Per-FedAvg(HF)
以下是代码的详细说明:

  1. 初始化方法
    • 客户端ID(client_id):标识客户端的唯一ID。
    • alphabeta:学习率和正则化参数。
    • global_model:全局模型,即所有客户端共享的模型。
    • criterion:损失函数。
    • batch_size:批量大小。
    • dataset:数据集名称,支持MNIST和CIFAR10。
    • local_epochs:本地训练轮次。
    • valset_ratio:验证集比例。
    • logger:日志记录器。
    • gpu:是否使用GPU。
  2. 训练方法
    • 首先,客户端加载全局模型的状态字典。
    • 然后,客户端进行本地训练。对于 Per-FedAvg(HF) 模式,训练分为两个阶段,每个阶段都包含一个梯度计算和模型更新。对于 Per-FedAvg(FO) 模式,代码中包含了注释掉的联邦平均(FedAvg)训练步骤。
    • 如果设置了 eval_while_training,客户端在本地训练前后都会进行评估。
    • 训练完成后,客户端返回序列化的模型状态。
  3. 计算梯度方法
    • 该方法计算模型参数的梯度。
    • 对于 Per-FedAvg(HF) 模式,它还会计算二阶导数。
  4. 个性化评估方法
    • 该方法用于个性化评估。客户端首先加载全局模型的状态字典,然后进行个性化训练和评估。
    • 个性化训练包括一个梯度计算和模型更新。
    • 评估完成后,客户端返回损失和准确率的变化。
  5. 辅助方法
    • get_data_batch:获取一个数据批量的方法。
    • utils.eval:评估模型的方法,从 utils 模块导入。
    • SerializationTool.serialize_model:序列化模型的方法,从 SerializationTool 模块导入。
      这个类为联邦学习中的客户端提供了训练和评估模型的功能,支持不同的训练模式和个性化评估。

main.py

这段代码是一个联邦学习实验的主程序,它负责初始化客户端、进行训练和评估,并记录日志。以下是代码的详细说明:

  1. 命令行参数解析
    • 代码首先使用 get_args 函数解析命令行参数,并使用 fix_random_seed 函数设置随机种子。
    • 检查是否存在日志目录,如果不存在则创建。
  2. 设备选择
    • 如果指定了GPU并且GPU可用,则选择GPU作为设备;否则选择CPU。
  3. 全局模型初始化
    • 使用 get_model 函数根据数据集名称初始化全局模型。
    • 创建日志记录器。
  4. 客户端初始化
    • 创建一个客户端列表,每个客户端都使用 PerFedAvgClient 类初始化。
  5. 训练循环
    • 使用 track 函数跟踪训练进度。
    • 随机选择客户端进行本地训练,并将序列化的模型参数存储在 model_params_cache 中。
    • 使用 Aggregators.fedavg_aggregate 函数聚合模型参数。
    • 反序列化聚合后的模型参数到全局模型。
  6. 评估循环
    • 使用 track 函数跟踪评估进度。
    • 对每个评估客户端执行个性化评估,并记录损失和准确率。
  7. 结果展示
    • 打印评估结果。
  8. 日志记录
    • 如果设置了日志记录,使用 logger.save_html 函数将日志保存为HTML文件。
      这个程序为联邦学习实验提供了一个完整的框架,包括客户端的初始化、模型的训练和评估,以及日志的记录和保存。
  • 31
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值