联邦学习经典算法总结

本文总结了多种联邦学习算法,包括经典FedAvg算法及其在非IID数据上的应用、带有正则化的FedProx、针对非IID特征的FedBN、模型对比学习MOON及FedNova等。这些算法针对不同类型的非独立同分布数据提出了有效的解决方案。

摘要生成于 C知道 ,由 DeepSeek-R1 满血版支持, 前往体验 >

看过很多联邦学习分类分割的文章了,现在来个总结吧。

1.FedAvg Communication-Efficient Learning of Deep Networks from Decentralized Datahttps://arxiv.org/abs/1602.05629

最经典的FL算法

论文里面无收敛分析证明,收敛分析证明需要看这篇文章

关于FedAvg在非IID数据上的趋同

算法:

聚合部分代码:

def average_weights(w):
    """
    Returns the average of the weights.
    """
    w_avg = copy.deepcopy(w[0])
    for key in w_avg.keys():
        for i in range(1, len(w)):
            w_avg[key] += w[i][key]
        w_avg[key] = torch.div(w_avg[key], len(w))
    return w_avg

代码分析:收集所有客户端的模型,然后将模型进行直接平均。

2.FedProx:在FedAvg局部客户端训练时加上一个近似项(就是加上正则化),目的是对偏离全局模型大的客户端进行惩罚,让参与训练的客户端收到约束。

论文:FEDERATED OPTIMIZATION IN HETEROGENEOUS NETWORKShttps://arxiv.org/abs/1812.06127

github:

tensorflow版本:GitHub - litian96/FedProx:异构网络中的联合优化 (MLSys '20)

pytorch版本:GitHub - ki-ljl/FedProx-PyTorch: PyTorch implementation of FedProx (Federated Optimization for Heterogeneous Networks, MLSys 2020).

 参数μ={0.001, 0.01, 0.1, 0.5, 1},参数μ需要调参 

论文里面有收敛分析证明。

算法:

客户端局部代码:

proximal_term = 0.0
for w, w_t in zip(model.parameters(), global_model.parameters()):
     proximal_term += (w - w_t).norm(2)
     loss = loss_function(y_pred, label) + (args.mu / 2) * proximal_term

服务器聚合代码:

    def aggregation(self, index):
        s = 0
        for j in index:
            # normal
            s += self.nns[j].len

        params = {}
        for k, v in self.nns[0].named_parameters():
            params[k] = torch.zeros_like(v.data)

        for j in index:
            for k, v in self.nns[j].named_parameters():
                params[k] += v.data * (self.nns[j].len / s)

        for k, v in self.nn.named_parameters():
            v.data = params[k].data.clone()

3.FedBN: ICLR

FedBN:通过本地批处理规范化对非 IID 特征进行联合学习https://arxiv.org/abs/2102.07623

代码:

GitHub - med-air/FedBN: [ICLR'21] FedBN: 通过本地批处理规范化对非 IID 特征进行联邦学习

FedBN使用局部批次归一化来缓解平均模型之前的特征漂移。

        比如医学成像中不同的扫描仪/传感器,自动驾驶(高速公路与城市)中不同的场景分布,其中本地客户端存储具有与其他客户端不同分布的示例,作者认为局部客户端数据在特征空间中的分布存在偏差,并将这种情况定义为特征漂移。由于医院中使用的不同成像机器和协议,例如不同的强度和对比度,图像外观可能会有很大的不同。具体操作就是:使客户端的BN层保持本地更新,而不需要在服务器上进行通信和聚合。

                   

 首先这幅图告诉我们,局部无BN的话,模型平均并不能得到一个很好的模型,而当局部有BN时,局部训练误差面变得相似  

        这幅图告对于一个最优权重w*,改变γ*会恶化该模型,而对于给定的最优BN参数γ*,改变w*也会降低质量。如果同时平均模型和BN参数将会造成较高的泛化误差,所有要将局部参数保留在本地,只平均不带BN参数的模型。

实验结果:

        同时实验还进行了外部泛化,如果来自另一个领域的新中心加入训练,我们可以将全局模型的非BN层参数传递到这个新中心。这个新的中心将计算自己的均值和方差统计量,并学习相应的本地BN参数。

 

算法:

# aggregate params
        if args.mode.lower() == 'fedbn':
            print("this is fedBN")
            for key in server_model.state_dict().keys():
                if 'norm' not in key:
                    temp = torch.zeros_like(server_model.state_dict()[key], dtype=torch.float32)
                    for client_idx in range(args.node_num):
                        temp += client_weights[client_idx] * models[client_idx].state_dict()[key]
                    server_model.state_dict()[key].data.copy_(temp)
                    for client_idx in range(args.node_num):
                        models[client_idx].state_dict()[key].data.copy_(server_model.state_dict()[key])
        else:
            print("this is not fedBN")
            for key in server_model.state_dict().keys():
                # num_batches_tracked is a non trainable LongTensor and
                # num_batches_tracked are the same for all clients for the given datasets
                if 'num_batches_tracked' in key:
                    server_model.state_dict()[key].data.copy_(models[0].state_dict()[key])
                else:
                    temp = torch.zeros_like(server_model.state_dict()[key])
                    for client_idx in range(len(client_weights)):
                        temp += client_weights[client_idx] * models[client_idx].state_dict()[key]
                    server_model.state_dict()[key].data.copy_(temp)
                    for client_idx in range(len(client_weights)):
                        models[client_idx].state_dict()[key].data.copy_(server_model.state_dict()[key])

 4.MOON CVPR

模型对比联邦学习

MOON的核心思想是利用模型表示之间的相似性来纠正个体的局部训练,即在模型级别进行对比学习。它通过最大化当前局部模型学习的表示与全局模型学习的表示的一致性来纠正局部更新。MOON的目的是减少局部模型学习的表示与全局模型学习的表示之间的距离,增加局部模型学习的表示与先前局部模型学习的表示之间的距离。

 

 实验结果:

  

fedProx的超参数设置为μ={0.001, 0.01, 0.1, 0.5, 1},默认一般为0.01

MOON的超参数μ可以设置为1,也可以设置为10.

5. FedNova

论文:[2007.07481] Tackling the Objective Inconsistency Problem in Heterogeneous Federated Optimization (arxiv.org)  FedNova的意思是每个客户端局部下降的迭代次数是T=N/B,N代表每个客户端的数据样本,B代表小批量数据Batchsize, 设想一下,当N的大小差别很大时,在给定的时间内,不同的客户端之间可能有很大差异,速度较快的客户端会比速度较慢的客户端执行更多的本地更新。(比如一个客户端A的样本大小为1000,一个客户端B的样本大小为200,假设我们规定的一轮的执行时间是4分钟,那么在这4分钟之内,客户端A是不是只执行了1次,但是客户端B已经执行了5次。但是某个客户端执行次数越多它自己收敛的越快,但是对于全局模型来说不一定会朝着更好的方向,因为局部模型之间的差异会越大。)

6. 一篇综述非独立同分布数据孤岛的联邦学习:一项实验研究

(257条消息) 非独立同分布数据孤岛的联邦学习:一项实验研究_我爱计算机视觉的博客-CSDN博客

这篇文章 介绍了六种non-IID数据的划分策略,充分考虑了标签分布的偏差、特征分布的偏差和数量偏差等不同情况。

 

<think>好的,我现在需要回答用户关于联邦学习算法的介绍、原理和实现的问题。首先,我需要确保自己正确理解联邦学习的核心概念。联邦学习是一种分布式机器学习方法,允许各方在本地数据不共享的情况下协作训练模型。这保护了用户隐私,符合GDPR等法规要求。 接下来,我应该介绍联邦学习的背景。用户提供的引用[1]提到联邦学习的背景是数据孤岛问题和隐私保护需求。传统方法需要集中数据,但现实中数据分散且隐私敏感,所以联邦学习应运而生。需要强调其核心目标:数据不动模型动,这样各方可以共享模型参数而非原始数据。 然后是原理部分。联邦学习的基本流程包括初始化、本地训练、参数聚合和迭代优化。用户提到的引用[2]和[3]分别涉及元学习和FedPer算法,这些可以作为原理的补充。例如,FedPer将网络分为基础层和个性化层,基础层由服务器聚合,个性化层保留在本地,这样既能共享知识又适应本地数据特性。需要解释本地训练使用随机梯度下降,服务器用FedAvg算法平均参数。可能还要提到通信效率和安全性,比如加密传输和差分隐私。 在实现方面,用户提供了FedPer的PyTorch示例,我需要解析代码结构。服务器端初始化全局模型,客户端负责本地训练。代码中的模型分割符合FedPer的结构,前三层是基础层,最后一层是个性化层。训练循环中,客户端下载全局参数,训练后上传基础层参数,服务器进行聚合。这部分需要详细说明代码的关键步骤,比如模型分割、参数更新方式。 之后要生成相关问题,覆盖挑战、与传统分布式学习的区别、个性化方法、隐私保护技术等,确保问题具有针对性和延伸性。 最后,检查是否遵循了用户的格式要求:行内公式用$...$,独立公式用$$...$$,正确使用引用标识,结构清晰,使用中文。确保没有使用Markdown格式,代码块正确,引用标注自然。同时,验证LaTeX语法是否正确,例如积分符号、上下标等是否正确书写。</think>### 联邦学习算法详解 #### 1. 背景介绍 联邦学习(Federated Learning)是一种分布式机器学习框架,允许多个参与方(如移动设备、医疗机构)在不共享本地数据的情况下协作训练全局模型。其核心思想是“数据不动,模型动”,通过传递加密的模型参数而非原始数据解决数据隐私和合规性问题[^1]。 #### 2. 核心原理 ##### 2.1 基本流程 联邦学习流程可分为以下步骤: 1. **全局模型初始化**:服务器生成初始模型参数$w_0$; 2. **本地模型训练**:客户端$i$用本地数据$D_i$更新模型参数: $$w_i^{(t+1)} = w^{(t)} - \eta \nabla \mathcal{L}(w^{(t)}; D_i)$$ 3. **参数聚合**:服务器通过联邦平均(FedAvg)算法聚合参数: $$w^{(t+1)} = \sum_{i=1}^N \frac{|D_i|}{|D|} w_i^{(t+1)}$$ 4. **迭代优化**:重复步骤2-3直至收敛[^1][^2]。 ##### 2.2 关键技术特性 - **通信效率**:通过压缩通信数据量(如稀疏化、量化) - **异构性处理**:元学习调整模型适应不同数据分布[^2] - **隐私保护**:差分隐私、同态加密等技术保障参数安全 #### 3. 算法实现(以FedPer为例) 引用[3]提出的FedPer算法将神经网络分为基础层(共享)和个性化层(本地保留)。以下为PyTorch实现核心逻辑: ```python # 服务器端 class GlobalModel(nn.Module): def __init__(self): super().__init__() self.base_layers = nn.Sequential( nn.Linear(784, 256), nn.ReLU(), nn.Linear(256, 128)) # 客户端 class ClientModel(GlobalModel): def __init__(self): super().__init__() self.personal_layers = nn.Linear(128, 10) def forward(self, x): x = self.base_layers(x) return self.personal_layers(x) # 训练循环 for round in range(total_rounds): # 客户端本地训练 for client in selected_clients: client.download(global_model.base_layers.state_dict()) local_train(client.model, local_data) upload_gradients(client.base_layers.parameters()) # 服务器聚合 global_model.aggregate(gradients) ``` #### 4. 典型应用场景 - **医疗领域**:医院联合训练疾病预测模型 - **金融风控**:银行间反欺诈模型协作 - **智能终端**:手机输入法词频预测更新
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值