联邦学习实战:用python从零实现横向联邦图像分类

本文详细介绍了如何使用PyTorch实现横向联邦学习的图像分类任务,包括环境配置、模型配置、数据集处理、服务端和客户端的实现。通过联邦学习,多个客户端在不共享数据的情况下协同训练ResNet-18模型,提升了模型的泛化能力和隐私保护。
摘要由CSDN通过智能技术生成

本文主要参考了杨强教授的《联邦学习实战》
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter03_Python_image_classification

一、环境配置

python3.7
GPU(可选):首先安装CUDA、cuDNN
安装pytorch,pip install torch

二、python实现横向联邦图像分类

2.1配置信息
  • 训练的客户端数量 no_models
  • 全局迭代次数 global_epochs
  • 本地模型的迭代次数 local_epochs
  • 本地训练相关的算法配置,lr(学习率)等
  • 模型信息:本案例用的是ResNet-18图像分类模型
  • 数据信息:本案例使用的是cifar10数据集

其他配置信息,例如是否使用差分隐私,模型聚合策略,可以根据需求自行添加。这里将上面的信息以json格式记录在配置文件中,以便修改。文件命名为conf.json

{
   
  "model_name" : "resnet18",  
  "no_models" : 10,
  "type" : "cifar",
  "global_epochs" : 20,
  "local_epochs" : 3,
  "k" : 6,
  "batch_size" : 32,
  "lr" : 0.001,
  "momentum" : 0.0001,
  "lambda" : 0.1 
}

上面的k指的是每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
联邦学习在模型训练之前,会将配置信息发送到客户端和服务端保存,如果配置信息更改,也会同时对所有参与方同步。

2.2训练数据集

这里使用的是torchvision的datasets模块内置的cifar10数据集, 项目文件下创建getDatasets.py文件

from torchvision.transforms import transforms
from torchvision import datasets
# 获取数据集
def get_dataset(dir, name):
    if name == 'mnist':
        # root: 数据路径
        # train参数表示是否是训练集或者测试集
        # download=true表示从互联网上下载数据集并把数据集放在root路径中
        # transform:图像类型的转换
        train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
        eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
    elif name == 'cifar':
        # 设置两个转换格式
        # transforms.Compose 是将多个transform组合起来使用(由transform构成的列表)
        transform_train = transforms.Compose([
            # transforms.RandomCrop: 切割中心点的位置随机选取
            transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            # transforms.Normalize: 给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ])
        train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
        eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
    return train_dataset, eval_dataset

2.3服务端

服务端类Server主要包括三种函数

  1. 定义构造函数
  2. 模型聚合函数
  3. 模型评估函数
2.3.1定义构造函数
class Server(object):
	# 定义构造函数
	def __init__(self, conf, eval_dataset):
	  # 导入配置文件
	  self.conf = conf
	  # 根据配置获取模型文件
	  self.global_model = models.get_model(self.conf["model_name"])
	  self.eval_loader = torch.utils.data.DataLoader(
	    eval_dataset,
	    # 设置单个批次大小
	    batch_size=self.conf["batch_size"],
	    # 打乱数据集
	    shuffle=True
	  )
2.3.2模型聚合函数

这里使用的是经典的FedAvg算法

def model_aggregate(self, weight_accumulator):
	# weight_accumulatot存储了每一个客户端的上传参数变化值
	# 遍历服务器的全局模型
	for name, data in self.global_model.state_dict().items():
		# 更新每一层
		update_per_layer = weight_accumulator[name] * self.conf["lambda"]
		# 累加和
		if data.type() != update_per_layer.type():
		 	# 因为update_per_layer的type是floatTensor,所以将起转换为模型的LongTensor(有一定的精度损失)
		 	data.add_(update_per_layer.to(torch.int64))
		else:
		    data.add_(update_per_layer)
2.3.3 模型评估函数

对当前的全局模型,利用评估数据评估当前的全局模型性能

def model_eval
评论 6
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
>