本文主要参考了杨强教授的《联邦学习实战》
https://github.com/FederatedAI/Practicing-Federated-Learning/tree/main/chapter03_Python_image_classification
一、环境配置
python3.7
GPU(可选):首先安装CUDA、cuDNN
安装pytorch,pip install torch
二、python实现横向联邦图像分类
2.1配置信息
- 训练的客户端数量 no_models
- 全局迭代次数 global_epochs
- 本地模型的迭代次数 local_epochs
- 本地训练相关的算法配置,lr(学习率)等
- 模型信息:本案例用的是ResNet-18图像分类模型
- 数据信息:本案例使用的是cifar10数据集
其他配置信息,例如是否使用差分隐私,模型聚合策略,可以根据需求自行添加。这里将上面的信息以json格式记录在配置文件中,以便修改。文件命名为conf.json
{
"model_name" : "resnet18",
"no_models" : 10,
"type" : "cifar",
"global_epochs" : 20,
"local_epochs" : 3,
"k" : 6,
"batch_size" : 32,
"lr" : 0.001,
"momentum" : 0.0001,
"lambda" : 0.1
}
上面的k指的是每一轮迭代时,服务端会从所有客户端中挑选k个客户端参与训练。
联邦学习在模型训练之前,会将配置信息发送到客户端和服务端保存,如果配置信息更改,也会同时对所有参与方同步。
2.2训练数据集
这里使用的是torchvision的datasets模块内置的cifar10数据集, 项目文件下创建getDatasets.py文件
from torchvision.transforms import transforms
from torchvision import datasets
# 获取数据集
def get_dataset(dir, name):
if name == 'mnist':
# root: 数据路径
# train参数表示是否是训练集或者测试集
# download=true表示从互联网上下载数据集并把数据集放在root路径中
# transform:图像类型的转换
train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
elif name == 'cifar':
# 设置两个转换格式
# transforms.Compose 是将多个transform组合起来使用(由transform构成的列表)
transform_train = transforms.Compose([
# transforms.RandomCrop: 切割中心点的位置随机选取
transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
# transforms.Normalize: 给定均值:(R,G,B) 方差:(R,G,B),将会把Tensor正则化
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
train_dataset = datasets.CIFAR10(dir, train=True, download=True, transform=transform_train)
eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
return train_dataset, eval_dataset
2.3服务端
服务端类Server主要包括三种函数
- 定义构造函数
- 模型聚合函数
- 模型评估函数
2.3.1定义构造函数
class Server(object):
# 定义构造函数
def __init__(self, conf, eval_dataset):
# 导入配置文件
self.conf = conf
# 根据配置获取模型文件
self.global_model = models.get_model(self.conf["model_name"])
self.eval_loader = torch.utils.data.DataLoader(
eval_dataset,
# 设置单个批次大小
batch_size=self.conf["batch_size"],
# 打乱数据集
shuffle=True
)
2.3.2模型聚合函数
这里使用的是经典的FedAvg算法
def model_aggregate(self, weight_accumulator):
# weight_accumulatot存储了每一个客户端的上传参数变化值
# 遍历服务器的全局模型
for name, data in self.global_model.state_dict().items():
# 更新每一层
update_per_layer = weight_accumulator[name] * self.conf["lambda"]
# 累加和
if data.type() != update_per_layer.type():
# 因为update_per_layer的type是floatTensor,所以将起转换为模型的LongTensor(有一定的精度损失)
data.add_(update_per_layer.to(torch.int64))
else:
data.add_(update_per_layer)
2.3.3 模型评估函数
对当前的全局模型,利用评估数据评估当前的全局模型性能
def model_eval