【阅读笔记】联邦学习实战——用Python从零实现横向联邦图像分类

联邦学习实战——用Python从零实现横向联邦图像分类

前言

本篇学习笔记记录的内容是杨强教授编写的《联邦学习实战》这本书的第三章内容,本来是准备记录在ipad上,但是用博客形式写下来可以跟更多人分享并讨论,这不失为更好的选择。前两章内容为理论基础,简单介绍了联邦学习(想要深入了解的朋友可以阅读杨强教授的《联邦学习》这本书)以及联邦学习的安全机制(同态加密、差分隐私、安全多方计算),这些内容我在之前的博客中有所涉及,就不再此赘述了。
第三章开始部分是配置运行环境,在我的这篇博客中有windows和ubuntu双系统配置教程,有需要的朋友可以看看。GPU 环境的配置需要根据显卡和ubuntu版本来定,这里推荐一篇文章供参考。但是要注意,30系列显卡支持的cuda版本至少为11,我推荐11.4,因为正好有对应的pytorch版本,pytorch的下载指令尽量用官网的,否则人云亦云会走很多弯路。
这一章内容主要从代码角度出发,来分析参与方和服务器在训练过程中的操作。


1. 代码分析

1.1 配置信息

参数设置:

  • 参与方数量
  • 全局迭代次数
  • 本地模型迭代次数
  • 算法配置,包括学习率、mini-batch、优化算法
  • 模型信息
  • 数据信息

在这里插入图片描述
以上的信息会保存在json文件中,在训练前分发给参与方和服务器。

1.2 训练数据集

为了便于选择数据集(可选mnist和cifar),将数据集的获取和预处理过程封装在get_dataset函数中,通过路径和数据集名称进行调用。

def get_dataset(dir, name):

	if name=='mnist':
		train_dataset = datasets.MNIST(dir, train=True, download=True, transform=transforms.ToTensor())
		eval_dataset = datasets.MNIST(dir, train=False, transform=transforms.ToTensor())
		
	elif name=='cifar':
		transform_train = transforms.Compose([
			transforms.RandomCrop(32, padding=4),
			transforms.RandomHorizontalFlip(),
			transforms.ToTensor(),
			transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
		])

		transform_test = transforms.Compose([
			transforms.ToTensor(),
			transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
		])
		
		train_dataset = datasets.CIFAR10(dir, train=True, download=True,
										transform=transform_train)
		eval_dataset = datasets.CIFAR10(dir, train=False, transform=transform_test)
		
	
	return train_dataset, eval_dataset

1.3 服务端

因为在本地运行,所以这里的server较为简单,忽略了网络监控、节点连接失败处理等问题。
server函数组成:

  • 构造函数。拷贝配置信息,根据信息获取模型。比如本章使用的是torchvision中内置的ResNet-18模型。
def __init__(self, conf, eval_dataset):
	
		self.conf = conf 
		
		self.global_model = models.get_model(self.conf["model_name"]) 
		
		self.eval_loader = torch.utils.data.DataLoader(eval_dataset, batch_size=self.conf["batch_size"], shuffle=True)
  • 定义聚合函数。接受客户端上传模型并聚合,本章采用经典的FedAvg算法。
def model_aggregate(self, weight_accumulator):
		# weight_accumulator存储了每个客户端上传的参数变化值
		for name, data in self.global_model.state_dict().items():
			
			update_per_layer = weight_accumulator[name] * self.conf["lambda"]
			
			if data.type() != update_per_layer.type():
				data.add_(update_per_layer.to(torch.int64))
			else:
				data.add_(update_per_layer)
  • 定义模型评估函数。用评估数据评估当前模型的性能,决定是否可以终止训练。
def model_eval(self):
		self.global_model.eval()
		
		total_loss = 0.0
		correct = 0
		dataset_size = 0
		for batch_id, batch in enumerate(self.eval_loader):
			data, target = batch 
			dataset_size += data.size()[0]
			
			if torch.cuda.is_available():
				data = data.cuda()
				target = target.cuda()
				
			
			output = self.global_model(data)
			
			total_loss += torch.nn.functional.cross_entropy(output, target,
											  reduction='sum').item() # sum up batch loss
			pred = output.data.max(1)[1]  # get the index of the max log-probability
			correct += pred.eq(target.data.view_as(pred)).cpu().sum().item()

		acc = 100.0 * (float(correct) / float(dataset_size))
		total_l = total_loss / dataset_size

		return acc, total_l

1.4 客户端

客户端功能是接收服务端下发指令和全局模型,并利用本地数据进行局部模型训练。
client端的组成:

  • 定义构造函数。拷贝配置信息,接收服务端传来的模型。
	def __init__(self, conf, model, train_dataset, id = -1):
		
		self.conf = conf
		
		self.local_model = models.get_model(self.conf["model_name"]) 
		
		self.client_id = id
		
		self.train_dataset = train_dataset
		
		all_range = list(range(len(self.train_dataset)))
		data_len = int(len(self.train_dataset) / self.conf['no_models'])
		train_indices = all_range[id * data_len: (id + 1) * data_len]

		self.train_loader = torch.utils.data.DataLoader(self.train_dataset, batch_size=conf["batch_size"], 
									sampler=torch.utils.data.sampler.SubsetRandomSampler(train_indices))
  • 定义模型本地训练函数。本例用交叉熵作为本地损失函数,梯度下降求解更新参数。
	def local_train(self, model):

		for name, param in model.state_dict().items():
			self.local_model.state_dict()[name].copy_(param.clone())
	
		# 定义优化函数,用于本地训练
		optimizer = torch.optim.SGD(self.local_model.parameters(), lr=self.conf['lr'],
									momentum=self.conf['momentum'])
		# 本地模型训练
		self.local_model.train()
		for e in range(self.conf["local_epochs"]):
			
			for batch_id, batch in enumerate(self.train_loader):
				data, target = batch
				
				if torch.cuda.is_available():
					data = data.cuda()
					target = target.cuda()
			
				optimizer.zero_grad()
				output = self.local_model(data)
				loss = torch.nn.functional.cross_entropy(output, target)
				loss.backward()
			
				optimizer.step()
			print("Epoch %d done." % e)	
		diff = dict()
		for name, data in self.local_model.state_dict().items():
			diff[name] = (data - model.state_dict()[name])
			
		return diff

1.5 整合

整合部分即main函数部分,导入配置信息,初始化客户端和服务端,客户端训练,并将训练模型进行融合。

if __name__ == '__main__':

	parser = argparse.ArgumentParser(description='Federated Learning')
	parser.add_argument('-c', '--conf', dest='conf')
	args = parser.parse_args()
	
	# 导入配置文件
	with open(args.conf, 'r') as f:
		conf = json.load(f)	
	
	# 导入数据集
	train_datasets, eval_datasets = datasets.get_dataset("./data/", conf["type"])
	
	# 初始化服务端和客户端
	server = Server(conf, eval_datasets)
	clients = []
	
	for c in range(conf["no_models"]):
		clients.append(Client(conf, server.global_model, train_datasets, c))
		
	print("\n\n")
	for e in range(conf["global_epochs"]):
		# 随机选取k个客户端
		candidates = random.sample(clients, conf["k"])
		
		weight_accumulator = {}
		
		# 权重初始化
		for name, params in server.global_model.state_dict().items():
			weight_accumulator[name] = torch.zeros_like(params)
		
		for c in candidates:
			diff = c.local_train(server.global_model)
			
			for name, params in server.global_model.state_dict().items():
				weight_accumulator[name].add_(diff[name])
				
		# 模型融合
		server.model_aggregate(weight_accumulator)
		
		acc, loss = server.model_eval()
		
		print("Epoch %d, acc: %f, loss: %f\n" % (e, acc, loss))

2. 模型效果

在这里插入图片描述
在这里插入图片描述

首先分析图一,可以看到联邦学习模型训练准确度和中心化模型学习准确度基本一样,并且学习速率很快,在第10轮迭代时已经接近最终结果,而中心化学习直到第20轮才趋近收敛。
分析图二可知,相对于单点训练,联邦训练的准确度远远高出20%个百分点,并且随着参与的客户端数量增多(即k值越大),性能越好,但是相对的,每轮完成时间也会相对较长。

  • 6
    点赞
  • 66
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 27
    评论
评论 27
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

HERODING77

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值