参考书目《联邦学习实战》 杨强
在阅读本书的过程中,我尝试根据书中的代码,自己实现横向联邦学习中的图像分类任务,这里是我对代码和逻辑的理解还有出现的问题,希望对大家的学习有所帮助。
下面的表格是一些实验基本信息:
配置信息 | 解释 |
---|---|
数据集 | Cifar10(其将样本划分后给每个客户端作为本地数据) |
全局迭代次数 | 服务器和客户端的通信次数 |
本地模型迭代次数 | 每一次客户端训练的轮数,各个客户端可以相同,也可以不同 |
一些其它基础的模型配置信息在json文件中给出: |
{
"model_name" : "resnet18",
"no_models" : 10,
"type" : "cifar",
"global_epochs" : 20,
"local_epochs": 3,
"k" : 6,
"batch_size" : 32,
"lr" : 0.001,
"momentum" : 0.0001,
"lambda" : 0.1
}
获取训练数据集函数dataset.py:
import torchvision.datasets as dataset
import torchvision.transforms as transform
def get_dataset(dir, name):
if name == 'mnist':
# 获取训练集和测试集
train_dataset = dataset.MNIST(dir, train=True, download=True, transform=transform.ToTensor()) # 设置下载数据集并转
# 换为torch识别的tensor数据类型
eval_dataset = dataset.MNIST(dir, train=False, transform=transform.ToTensor()) # 测试集
elif name == 'cifar':
transform_train = transform.Compose([ # 数据增强操作,训练集的预处理
transform.RandomCrop(32, padding=4), # 随机剪裁,大小为32*32,添加4个像素的填充内容
transform.RandomHorizontalFlip(), # 随机垂直方向的翻转
transform.ToTensor(),
transform.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) # 归一化操作,数值是抽样得到的,无需考虑太
# 多,分别是均值和标准差
])
transform_test = transform.Compose([ # 对测试集进行预处理
transform.ToTensor(),
transform.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
])
train_dataset = dataset.CIFAR10(dir, train=True, download=True, transform=transform_train) # 获得训练集
eval_dataset = dataset.CIFAR10(dir, train=False, transform=transform_test) # 获得测试集
return train_dataset, eval_dataset
这是一个简单的测试,采用本地模拟的方式进行客户端和服务器的交互,现在定义一个服务端类Server,其中的聚合函数采用的是FedAvg算法更新全局模型,公式如下:
G t + 1 = G t + λ ∑ i = 1 m ( L i t + 1 − G