✅博主简介:本人擅长建模仿真、数据分析、论文写作与指导,项目与课题经验交流。项目合作可私信或扫描文章底部二维码。
联邦学习作为解决数据孤岛问题的有效手段,能在保证用户隐私的情况下联合参与节点共建全局模型。然而,在联邦学习中,参与训练的节点通常来自不同地理位置、组织、用户群体或应用场景,这会导致数据来源和特征分布不同,产生数据异构性问题,进而使联邦学习全局模型偏移。同时,中心式联邦学习受限于中心服务器,会减慢算法收敛速度,降低全局模型推理精度。因此,在不引入大量通信开销的前提下提升数据异构条件下的联邦学习算法性能,成为当前联邦学习领域亟需解决的核心问题。
二、具体优化内容与贡献
-
基于惩罚正则项的联邦学习(FedAvg-Z 算法)
- 针对全局模型与本地模型之间的权重差导致的模型精度下降问题,对 FedAvg 算法中的聚合阶段进行改进,提出 FedAvg-Z 算法。
- 通过增加惩罚正则项,强制局部模型拟合,减小权重差异,从而提高联邦学习的性能。
- 在 MNIST 和 CIFAR-10 数据集上进行两种不同参数设置的实验,结果表明该算法可有效提升联邦学习算法在数据异构场景下的模型准确率。
-
基于节点选择和知识蒸馏的去中心化联邦学习
- 联邦学习通常采用随机采样策略,这会造成模型泛化能力差,且中心节点因需汇聚大量模型参数而导致通信压力大和网络拥塞。
- 采用当前轮次的局部模型性能表现和本地数据集大小来体现不同设备的数据质量差异,通过节点之间的双向选择确定本轮中心节点和邻邦节点,实现联邦学习的去中心化。
- 在联邦学习中加入知识蒸馏机制,以轻微性能损失为代价加快联邦学习运行时间。
- 在 MNIST、CIFAR-10 和 FEMNIST 数据集上进行实验,结果表明该算法可在保证模型性能的同时有效缩短运行时间。
-
基于分割学习的联邦学习
- 针对在计算资源受限的节点中训练复杂模型困难的问题,在联邦学习中引入分割学习。
- 将完整的机器学习模型拆分,节点训练部分网络,其余由计算服务器完成,减少节点端计算负载,同时结合上文节点选择机制加强收敛。
- 在 MNIST、CIFAR-10 和 FEMNIST 数据集上使用 ResNet-18 和 AlexNet 作为复杂网络并针对不同分割位置进行实验。
- 实验结果表明,该算法可有效提高联邦学习在数据异构场景下对复杂模型的处理能力。
三、各优化方法的具体实现与效果
-
FedAvg-Z 算法的实现与效果
- 在 FedAvg 算法的聚合阶段,引入惩罚正则项。当各个节点完成本地模型训练后,在聚合全局模型时,考虑全局模型与本地模型之间的权重差异,并通过惩罚正则项来约束本地模型向全局模型靠近。
- 在 MNIST 和 CIFAR-10 数据集上的实验中,通过调整不同的参数设置,可以观察到 FedAvg-Z 算法在数据异构的情况下,能够有效地提高模型的准确率。这表明惩罚正则项的引入确实有助于减小全局模型与本地模型之间的差异,从而提升联邦学习的性能。
-
基于节点选择和知识蒸馏的去中心化联邦学习的实现与效果
- 对于节点选择,根据当前轮次的局部模型性能表现和本地数据集大小来评估不同设备的数据质量。性能表现较好且数据集较大的节点更有可能被选为中心节点或邻邦节点。这样的双向选择机制可以确保参与联邦学习的节点具有较高的数据质量和代表性。
- 知识蒸馏机制的引入,使得在联邦学习过程中,性能较好的模型可以将其知识传递给性能较差的模型,从而加快整体的学习速度。在实验中,这种方法在保证模型性能的同时,显著缩短了联邦学习的运行时间,减轻了中心节点的通信压力和网络拥塞问题。
-
基于分割学习的联邦学习的实现与效果
- 在联邦学习中引入分割学习时,首先确定模型的分割位置。将复杂的机器学习模型拆分为多个部分,节点只负责训练其中的一部分网络,而其余部分由计算服务器完成。这样可以大大减少节点端的计算负载,使得在计算资源受限的节点上也能进行有效的模型训练。
- 结合节点选择机制,选择合适的节点来训练不同部分的模型,进一步加强了联邦学习的收敛性。在不同数据集上使用 ResNet-18 和 AlexNet 等复杂网络进行实验,并针对不同的分割位置进行调整,可以观察到该算法在数据异构场景下对复杂模型的处理能力得到了有效提高。
-
import numpy as np import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms # 定义简单的神经网络模型 class SimpleNet(nn.Module): def __init__(self): super(SimpleNet, self).__init__() self.fc1 = nn.Linear(784, 128) self.fc2 = nn.Linear(128, 10) def forward(self, x): x = x.view(-1, 784) x = torch.relu(self.fc1(x)) return self.fc2(x) # 模拟联邦学习中的节点 class Node: def __init__(self, data): self.model = SimpleNet() self.optimizer = optim.SGD(self.model.parameters(), lr=0.01) self.data = data def train(self): for epoch in range(5): for images, labels in self.data: self.optimizer.zero_grad() outputs = self.model(images) loss = nn.CrossEntropyLoss()(outputs, labels) loss.backward() self.optimizer.step() return self.model.state_dict() # 模拟联邦学习的过程 def federated_learning(): transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) mnist_train = datasets.MNIST('data', train=True, download=True, transform=transform) mnist_test = datasets.MNIST('data', train=False, download=True, transform=transform) # 划分数据给不同节点 node1_data = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(range(0, 10000))) node2_data = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(range(10000, 20000))) node3_data = torch.utils.data.DataLoader(mnist_train, batch_size=64, shuffle=True, sampler=torch.utils.data.SubsetRandomSampler(range(20000, 30000))) nodes = [Node(node1_data), Node(node2_data), Node(node3_data)] # 全局模型初始化 global_model = SimpleNet() global_optimizer = optim.SGD(global_model.parameters(), lr=0.01) for round in range(5): local_models = [] for node in nodes: local_weights = node.train() local_models.append(local_weights) # 模拟聚合过程 averaged_weights = {} for key in local_models[0].keys(): averaged_weights[key] = sum([local_model[key] for local_model in local_models]) / len(local_models) global_model.load_state_dict(averaged_weights) # 测试全局模型 test_loader = torch.utils.data.DataLoader(mnist_test, batch_size=64, shuffle=False) correct = 0 total = 0 with torch.no_grad(): for images, labels in test_loader: outputs = global_model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item() print(f'Accuracy of the network on the test images: {100 * correct / total}%') if __name__ == '__main__': federated_learning()