SemiFL: Semi-Supervised Federated Learning for Unlabeled Clients with Alternate Training论文讲解

前言

  这是前段时间组会时候讲解的一篇论文,这篇论文下载地址:arxiv下载地址,组会PPT和讲稿下载地址:(待更新)。
  这是篇半监督学习和联邦学习相结合的论文,本文的创新点在于提出了两个方法:1.每个通信轮次后,用服务器端的有标记数据对全局模型进行微调;2.在全局模型下发至客户端时,用刚下发的全局模型生成伪标签。并且本文对强数据增强进行了理论分析。
  本文参考了FixMatch论文中生成伪标签的方法,并使用了MixMatch论文中使用的Mixup方法对未标记数据进行混合产生新的数据集,达到扩充数据集的目的。本文还使用了HeteroFL论文中提到的静态批归一化(sDN)方法。

1. 引言

  作者介绍:这篇论文的名称是《SemiFL:通过替代训练为未标记客户提供半监督联邦学习》,作者是杜克大学的刁恩茂、瓦希德·塔罗克 和明尼苏达大学双城分校的丁杰。
  动机:现有的大多数联邦学习研究都集中在假设客户端拥有真实标签的监督学习任务上。但在许多实际情况下,客户端不是专家,无法标注他们的数据。例如,一个医疗系统可能有一个拥有领域专家和有限数量标签数据的中心节点(服务器),而有很多非专家和大量未标注数据的客户端。

  下图展示了具有标记数据的服务器可以通过与具有未标记数据的分布式客户端一起工作而无需数据共享来显着提高其学习性能。

在这里插入图片描述

2. 方法

2.1 要解决的问题

  在半监督学习分类任务中,我们有两个数据集,即监督数据集S和无监督数据集 U。(无监督)未标记数据集的大小往往远大于(监督)标记数据集的大小。
  下图是 FL 与 SSFL 方法的普通组合的对比图。FixMatch与 FedSGD 配合使用准确率比较高,但是需要批量梯度聚合,因此通信效率不高。FixMatch和FedAvg相结合的方法允许本地进行多个epoch再上传模型参数到服务器,减少了通讯轮次,但是准确率随通信轮数的增加而降低。因此,如何在训练多个本地 epoch 的通信高效的联邦学习(FL)场景中结合半监督(SSL)方法需要我们考虑。

在这里插入图片描述
知识扩充

  问题1. 使用FedAvg + FixMatch时准确率随通信轮数增加而降低,可能有几个原因:
  1.过拟合:模型在客户端的本地数据上训练太多轮次,而这些数据可能并不代表整个数据分布,模型可能会对这些特定的数据过拟合。当这个本地过拟合的模型参数被送回服务器并与其他模型参数聚合时,它可能会损害全局模型的泛化能力。
  2.错误的伪标签积累:初期生成的伪标签质量不高,错误的伪标签可能会在训练过程中积累。因为这些错误的标签被用来进一步训练模型,它们会导致模型性能下降。

  问题2. FixMatch与 FedSGD 配合使用准确率比较高,但通信效率不高的原因:
  1.FedSGD(联邦随机梯度下降)是一种联邦学习算法,客户端在本地计算梯度后发送给服务器聚合,导致数据传输量大。FixMatch是半监督学习法,通过为未标记数据生成伪标签来训练模型,要求模型对这些数据预测高置信度以产生高质量伪标签。在联邦学习中,这要求频繁通信以同步全局模型和客户端模型。FedSGD由于每次本地计算后即通信,自然适合需要频繁更新的场景,如与FixMatch结合使用。
  2.然而,这可能导致通信效率低下,尤其是在客户端众多时。尽管FedSGD能保持同步,但可能需多通信轮次以达到满意性能,而FedAvg通过减少通信频率,可能在效率上更优,但可能降低与如FixMatch等SSL方法的兼容性。

2.2 强数据增强

在这里插入图片描述

这幅图解释说明了强数据增强在半监督学习(SSL)中的作用,其具体过程和作用如下所示:
实线:最优决策边界。虚线:高置信度伪标签的边界。Pu:指未标记数据的概率分布区域。Pl:指标记数据的概率分布区域。)
  1. 从未标记数据中挑选样本:图中以蓝色三角形和X表示的点代表从未标记数据集中选出的数据。
  2. 生成高置信度的伪标签:对这些蓝色的数据(未标记数据)进行预测,生成高置信度的伪标签。图中的虚线区域代表了高置信度的伪标签的边界。
  3. 应用强数据增强技术:将选定的未标记数据进行强数据增强处理,这种处理方式可能包括大幅度的颜色变换、剪切、旋转等操作,目的是创造出与原始样本在视觉上有明显区别的新样本。增强后的样本在图中以红色的三角形和X表示。
  4. 将增强样本视为训练数据:将增强样本及其伪标签一起被用作训练数据,输入到模型中进行学习。
强数据增强的目的:由于增强样本覆盖了原始未标记样本可能不足的区域,扩展训练数据集

知识扩充:
  对于强数据增强的理解
  强数据增强不仅仅是改变数据以避免过拟合,更是一种策略,它将高质量的数据点转变为质量较低的数据点。 这里的“质量”指的是数据在视觉上或者从分类难度上的变化,而不是指数据本身的真实性或信息含量。所谓的质量较低,意味着增强后的数据在视觉上可能更加模糊、更具挑战性或与原始数据在某种程度上不同适用于有标记数据覆盖不足的数据领域,即原始数据集中缺乏标记数据的部分。强数据增强可以确保在这些区域有足够的“观测”,从而使模型可以从更多样化的数据中学习。

  下图是基于 RandAugment 技术的强数据增强示例。随着扭曲幅度的增加,增强的强度也会增加。这里,“Original”表示原始图像,“ShearX”表示沿水平轴剪切图像;“AutoConstrast” 是自动对比度调整,通过增加图像的对比度来使细节更清晰。

在这里插入图片描述

2.3 交替训练(两种方法)

在这里插入图片描述

  图 4(a):普通的联邦半监督过程
  1.服务器模型使用标注数据进行训练,而客户端模型使用未标注数据。
  2.模型权重:Ws t-1 表示前一轮迭代后服务器模型的权重,Wu,1t 到 Wu,Mt 表示当前迭代后各个客户端模型的权重。
  3.生成伪标签:在每一轮通信后,客户端使用自己的模型在各自的未标注数据批次上生成伪标签。
  4.模型聚合(Aggregate): 各个客户端使用其未标注数据和伪标签训练本地模型后,将这些模型发送到服务器进行聚合,更新服务器模型的权重 Wst。

  在通信高效的联邦学习环境中,无法保证伪标签的质量会在训练过程中提升,因为允许本地客户端进行多个epoch的训练,这可能会降低性能。过拟合:模型在客户端的本地数据上训练太多轮次,而这些数据可能并不代表整个数据分布,模型可能会对这些特定的数据过拟合。错误的伪标签积累:初期生成的伪标签质量不高,错误的伪标签可能会在训练过程中积累。
  因此,作者提出(两个创新点):
  1.使用有标签的数据重新训练全局模型(Fine-tune global model with labeled data),这样做可以为下一轮的活跃客户端提供一个与前一轮相当或更好的模型来生成伪标签。
  2.当活跃客户端从服务器接收到全局模型后,立即使用该模型对未标记的数据生成伪标签(Generate pseudo-labels with global model) ,这种方法的好处是伪标签的质量不会因为客户端在本地训练过程中而下降。为我们允许本地客户端训练多个 epoch,这可能会降低性能(如FedAVG+FixMatch)。
  上述两个创新点体现在图4(b)中。

2.4 公式

  公式 (1) 描述了服务器如何使用梯度下降更新其模型权重。这涉及到计算损失 Ls 并取其梯度,然后更新权重Ws。模型f的输入是经过弱数据增强α处理的数据批次xb,Ws 是服务器模型当前的权重。α(xb):表示对数据批次xb 应用的弱数据增强,如随机水平翻转或随机裁剪。这种增强可以增加数据的多样性,并有助于防止过拟合。
在这里插入图片描述
  公式 (2) 展示了如何在客户端生成伪标签。客户端模型权重Wu,m 使用服务器传输的权重Ws 更新,然后使用增强后的未标注数据α(xu,m)和Wu,m 生成伪标签 yu,m。

在这里插入图片描述

  公式 (3) 定义了如何构建高置信度数据集Du,mfix,类似于FixMatch算法。只有当伪标签的置信度超过阈值τ时,数据点(xu,m,yu,m) 才会被包括在Du,mfix 中。
在这里插入图片描述

  公式(4) 描述了如何从高置信度数据集Du,mfix 中创建一个新的数据集 Du,mmix,这个新数据集将用于Mixup增强,这表示从Du,mfix 中进行采样,采样的数量等同于 Du,mfix 中的元素数量,采样是带有替换的,意味着同一个数据点可以被多次选中。这个公式的目的是确保Du,mmix 数据集中的数据点足够多,以便用于后续的Mixup数据增强过程。Mixup增强涉及到将两个数据点按一定比例混合,以产生新的、合成的数据点。

在这里插入图片描述
补充: Mixup混合方法步骤:
  1.取两个样本 (x1,p1)和 (x2,p2),其中x代表样本的特征,而p代表样本的标签
  2.MixUp会生成一个新的样本 (x′,p′),其中x′是x1和x2的线性组合, p′是p1和p2的线性组合。
  3.线性组合是通过一个从Beta分布中采样的参数λ来控制的。Beta分布是一个取值在0到1之间的概率分布,由一个超参数α控制。在这个上下文中,λ和1−λ分别表示保留第一个样本和第二个样本的比例。这里的λ′实际上是为了确保λ更倾向于接近1或者0,这样新生成的样本x′会更接近x1或者x2,而不是位于它们正中间。这通过取λ和1−λ中的最大值来实现。
  最终,通过这种方法产生的数据集能够提供额外的正则化和增强模型对新数据泛化能力的训练样
本。
  下图是在MixMatch论文中Mixup方法的流程

Mixup示意图

在这里插入图片描述

  公式(5) 描述了如何计算Mixup数据增强过程中的损失函数,Lfix是用于fix数据的损失函数。它计算模型f 在增强后的fix数据A(xbfix) 上的输出与fix标签ybfix 之间的差异。
  Lmix是用于混合数据的损失函数。它是两部分的加权和,这两部分分别计算模型在弱数据增强的混合数据α(xmix) 上的输出与固定标签ybfix 以及混合标签ybmix 之间的差异。

在这里插入图片描述

公式 (6) 定义了客户端如何更新其模型参数,
在这里插入图片描述

2.5 伪代码

伪代码部分比肩简单,通过对上述公式的理解能够捋清楚伪代码流程:

在这里插入图片描述
在这里插入图片描述

2.6 实验设置

在这里插入图片描述

2.7 结果

  与SSL方法的比较: 我们在表1中展示了完全监督和部分监督情况以及现有SSL方法的结果进行比较。完全监督情况是指所有数据都被标记,而在部分监督情况下,我们只用部分标记的 NS 数据。我们的结果明显优于部分监督的情况。换句话说,SemiFL 可以在通信高效的场景中显着提高带标记服务器与未标记客户端的性能。我们的方法与最先进的 IID 数据分区 SSL 方法相比具有竞争力。此外,可以预见的是,随着客户端对非独立同分布数据分区的标签变得更加偏斜,我们的方法的性能会下降。然而,即使是最标签倾斜的未标记客户端也可以使用我们的方法提高标记服务器的性能。我们工作的一个限制是,随着监督数据大小的减小,SemiFL 的性能比集中式 SSL 方法下降得更多。我们认为这是因为我们无法在一批数据中同时训练标记和未标记数据。
  与 FL 和 SSFL 方法的比较: 我们将我们的结果与表 1 中最先进的 FL 和 SSFL 方法进行比较。我们证明 SemiFL 可以与使用完全监督数据训练的最先进的 FL 结果相媲美。值得一提的是,在非 IID 数据分区情况下,SSFL 可能优于 FL 方法,因为服务器具有一小组带标签的 IID 数据。我们还证明我们的方法明显优于现有的 SSFL 方法。现有的 SSFL 方法无法与最先进的集中式 SSL 方法紧密配合,即使它们的底层 SSL 方法是相同的。此外,现有的 SSFL 方法无法优于部分监督的情况,这表明它们降低了标记服务器的性能。特别是,FedMatch 为服务器和客户端分配不相交的模型参数,FedRGD 为服务器模型分配更高的权重进行聚合。两种方法都不会直接使用标记数据对全局模型进行微调,并使用接收到的全局模型生成伪标签。据我们所知,所提出的 SemiFL 是第一个 SSFL 方法,它实际上提高了标记服务器的性能,并且性能接近最先进的 FL 和 SSL 方法。

在这里插入图片描述
消融研究:
  1. 使用标记数据微调全局模型”和“使用全局模型生成伪标签”
  通过使用 CIFAR10 数据集对替代训练的每个组成部分进行消融研究。 “使用标记数据微调全局模型”和“使用全局模型生成伪标签”的结合显着提高了性能。

在这里插入图片描述

  2.静态批量归一化(sBN)
  仅用服务器端更新的sBN(server only): 这指的是在每次通信轮次中,只使用服务器端的数据来更新全局sBN统计信息,即全局模型的均值(μ)和方差(σ^2)。
  既用服务器又用客户端更新的sBN(server only server and clients): 这指的是在每次通信轮次中,结合来自服务器端和所有客户端的数据来更新全局sBN统计信息。这样做可能需要从每个客户端上传它们各自的BN统计信息,从而可以更精确地估计整个分布的均值和方差。
  我们通过消融实验可以看到:仅使用服务器数据来更新全局sBN统计信息,并不会降低性能。

在这里插入图片描述
  3.局部训练时期 E 的数量、Mixup 数据增强和全局 SGD 动量 βg
  我们对实验中采用的训练技术进行了消融研究。我们研究了局部训练时期 E 的数量、Mixup 数据增强和全局 SGD 动量 βg [10] 的效果,如表 5 所示。由于收敛缓慢,局部训练时期较少会显着损害性能。 Mixup 数据增强使 CIFAR10 数据集的准确度提高了约 2%。它表明将强数据增强与 Mixup 数据增强相结合对于训练未标记数据是有益的。全球势头略微改善了结果。

在这里插入图片描述

伪标签的质量:
  下图(图 6)是通过使用 CIFAR10 数据集测量伪标签的质量来进行替代训练的消融研究。 “Fine Tune”和“Global”分别指的是我们提出的方法“使用标记数据微调全局模型”和“使用全局模型生成伪标签”。 “Average”是指普通的 FL 方法,它直接取标记服务器和未标记客户端的模型参数的平均值。 “Training”是指在每批本地训练中生成伪标签。
  Pseudo Accuracy (伪标签精度): 这个指标量化了生成的伪标签的准确性。理想情况下,这些伪标签会尽可能接近真实的标签。
  Threshold Accuracy (阈值精度): 这个指标衡量了只有高于某个置信度阈值的伪标签的准确性。这有助于去掉那些可能不准确的伪标签。伪标签会根据模型预测的置信度被接受或拒绝。这通常涉及到设置一个置信度阈值,只有当模型对其预测足够确定时(例如,预测概率超过某个阈值),才会为样本分配伪标签。
  Label Ratio (标签比例): 这个指标衡量了达到置信度阈值的伪标签所占的比例。例如,如果有1000个无标签样本,模型对其中600个样本的预测置信度超过了设定的阈值,那么Label Ratio就是0.6(600/1000)。
  Label Ratio随着通信轮次增加而上升,这可能反映了随着模型训练的进行,模型对于更多样本产生了足够高的置信度预测,从而增加了伪标签的使用量。这可以作为模型学习进度和对无标签数据把握程度的一个指标。

在这里插入图片描述

3. 扩展知识

1.批归一化(Batch Normalization, BN)可以加速模型训练并提高稳定性。BN通过对每层的输入数据进行规范化处理来减少内部协变量偏移。然而,BN需要在每层计算数据的均值和方差,这在分布式学习中会导致问题:
  1.通信成本高:如果把这些统计信息上传到服务器,将会增加通信成本。
  2.隐私问题:这可能会泄露关于本地数据的信息。
  为了解决这些问题,提出了一种静态批归一化方法(sBN),它在训练期间不追踪运行中的统计信息,而是简单地标准化批处理数据。sBN不考虑本地的运行统计信息,因为这些信息可能会动态变化,而是在训练结束后,服务器会从各个本地客户端中逐个查询并累积更新全局的BN统计信息。

2.强数据增强和弱数据增强:
  弱数据增强改变较小,目的是轻微地增加数据的多样性;强数据增强则进行较大的改变,以促进模型学习到更鲁棒的特征。下面是更详细的分析:
(1)弱数据增强(Weak Data Augmentation):
目的: 弱数据增强旨在进行轻微的修改来增加数据集的多样性,而不显著改变数据的核心内容和结构。
方法: 常用的弱数据增强技术包括水平翻转、轻微的裁剪和旋转、亮度和对比度调整等。
效果: 这种方法足以提供小的变化,帮助模型抵抗过拟合,但又不至于改变图像的基本特征。
使用场景: 通常在训练开始时或者在不需要模型进行过于复杂推理的情况下使用。
(2)强数据增强(Strong Data Augmentation):
目的: 强数据增强的目的是显著改变数据,迫使模型学习到更为泛化的特征表示,能够处理更广泛的变化。
方法: 强数据增强包括随机剪切、旋转到极端角度、颜色扭曲、添加噪声、应用滤镜等,这些变化比弱数据增强更加激进。
效果: 强数据增强生成的图像可能在视觉上与原始图像有很大的不同,但它们仍然保留了被分类任务所需的关键信息。
使用场景: 在模型需要对数据的高层次特征进行推理,或者需要从数据中学习更抽象的表示时使用。
  联系与区别:
联系: 两者都是数据增强的策略,用于通过人为扩展数据集的多样性来提高模型的泛化能力。
区别: 主要在于应用的强度和目标。弱数据增强改变较小,目的是轻微地增加数据的多样性;强数据增强则进行较大的改变,以促进模型学习到更鲁棒的特征。
  在实际应用中,弱数据增强和强数据增强可以结合使用。例如,在半监督学习和自监督学习中,一个常见的模式是对有标签的数据使用弱数据增强,而对生成伪标签的未标记数据使用强数据增强,这样可以同时保持标签信息的准确性并推动模型从更复杂的变化中学习。

3.充分传输(adequate transmission)理论:
  论文提出了一种理论,它是基于“充分传输”(adequate transmission)的直观假设。这个理论假设强调,通过高置信度的未标记数据生成的增强数据能够充分覆盖在预测时感兴趣的数据领域。换句话说,由未标记数据所展示的可靠信息能够传递到标记数据训练不足的数据领域。
  在此基础上,这段理论描述了以下几个关键点:
1.可靠信息的传输:未标记数据提供的信息可以传递到通常由于标记数据不足而训练不充分的数据区域。这种信息的传递使得模型能在那些未标记或标记不充分的领域内进行更准确的预测。
2.研究范围的限制:作者指出,他们没有全面地研究自监督学习(SSL),而是专注于非参数核方法(nonparametric kernel-based classification learning)的一个特定类别。非参数核方法是一类基于数据点之间相似性度量的分类学习算法。
3.统计风险率分析:作者基于这些核方法,进行了解析上可处理的统计风险率分析,即定量评估模型预测错误的风险随数据数量的减少而变化的速率。

持续更新中……

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值