2023CVPR_Class Balanced Adaptive Pseudo Labeling for Federated Semi-Supervised Learning

联邦半监督学习:Class Balanced Adaptive Pseudo Labeling for Federated Semi-Supervised Learning
发表:2023 CVPR
代码:https://github.com/minglllli/CBAFed

1.1 Introduction

联邦学习(FL)旨在以分散的方式训练机器学习模型,同时保护数据隐私;
近年来,由于隐私保护原因,FL 受到了广泛关注,然而,大多数 FL工作都集中在完全标记数据的监督学习上。但在实践中,大规模训练数据的标记过程既费力又昂贵;
本文重点关注联邦半监督学习(FSSL),既假设很少有客户端拥有完全标记的数据(标记客户端),而其他客户端中的训练数据集未标记(未标记客户端)。
在这里插入图片描述

联邦半监督学习的三个分类
分类一:所有本地客户端数据集中都包含一部分标签数据和一部分无标签数据;
分类二:部分本地客户端数据集中都是标签数据,其他客户端数据集都是无标签数据
分类三:中央服务器包含部分标签数据。
在这里插入图片描述

1.2 Methods

联邦半监督学习(FSSL)的类平衡自适应伪标签(CBAFed)方法流程图:
在这里插入图片描述

总体框架:
①预热阶段,在完全标注客户端训练模型
②中央服务器计算经验分布并获得类平衡自适应阈值
③本地客户端更新模型,无标签客户端生成伪标签以训练模型,上传类数据量和模型给中央服务器
④中央服务器聚合模型,计算类的分布,获得类别平衡自适应阈值
⑤重复以上步骤,直至达到指定的通信轮数
在这里插入图片描述

本地标签客户端上的训练:
类似于ResNet。在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。人们认为卷积层和池化层的层数越多,获取到的图片特征信息越全,学习效果也就越好。然而随着层数的增加,预测效果反而越来越差。因此提出了ResNet,让神经网络某些层跳过下一层神经元的连接,隔层相连,弱化每层之间的强联系。
在这里插入图片描述
本文提出了一个残差权重连接的方法:

在训练过程中,若通信轮数是我们设置的跳跃次数的倍数的话,本次通信轮数的训练,每隔s次进行一次连接,最后的模型进行公式2的计算,有点类似于EMA方法,如果通信轮数不是我们设置的跳跃次数的倍数的话,则正常训练。
在这里插入图片描述

在标签客户端训练好模型后上传到服务器,无标签客户端下载模型进行推理伪标签,最后使用无标签数据及其伪标签训练无标签客户端的模型。
在这里插入图片描述
标签数据的预训练,训练loss:
在这里插入图片描述
生成伪标签:
在这里插入图片描述

无标签数据及其伪标签构成的训练数据:
在这里插入图片描述

伪标签的类平衡自适应阈值:
在最近的一项半监督工作中[1],提出了类伪标签,并其中时间步 t 处 c 类的灵活阈值计算如下:
在这里插入图片描述

其中βt是c类中所选伪标签的数量与所有类中所选伪标签的最大数量的比率,用于缩放固定阈值τ。
[1]Bowen Zhang, Yidong Wang, Wenxin Hou, Hao Wu, Jindong Wang, Manabu Okumura, and Takahiro Shinozaki.Flexmatch: Boosting semi-supervised learning with curriculum pseudo labeling. Advances in Neural Information Processing Systems, 34:18408–18419, 2021

然而,联邦半监督学习中存在异构数据分区问题。由于非独立同分布数据,标记数据不平衡,因此纯粹使用选定的未标记数据的数量来设计阈值是不合适的。因为某些类别的数据在某些客户端中可能极其稀少,这将导致βt©非常低,从而导致阈值非常低,在某种极端的情况下,如果某些客户端不存在这个分类,会导致这个阈值为0。因此本文作者提出了伪标签的类平衡自适应阈值方法。

伪标签的类平衡自适应阈值计算:
①无标签数据:计算从类C标签中选定的伪标签数量,1为指示函数,因此当我们所选的该标签为类C标签,并且预测值大于Tt的时候选定该标签,+1
在这里插入图片描述
②标签数据的类C数量:标签客户端中有多少类C标签记为多少
在这里插入图片描述

③所有客户端中选定的类c数量:n+m指的是无标签客户端和有标签客户端相加,需要注意的是,在第一轮通信轮次中只有标签客户端进行训练,而无标签客户端参与联邦学习训练时在第二轮次
在这里插入图片描述

④计算训练数据的经验分布:类c的总数/所有类的总数
在这里插入图片描述

⑤计算经验分布的标准差
在这里插入图片描述

⑥计算C类的阈值,同时设置阈值的上限,Th为0.95
在这里插入图片描述
在这里插入图片描述

针对长尾不平衡分布问题提出创新:
举个简单的例子,如果要做一个动物分类数据集,猫狗等常见数据可以轻轻松松的采集数以百万张的图片,但是罕见动物的样本就很少,所以尾类数据很少会被分类为正确的类。为了增强学习能力并利用尾类数据中的信息(而不是直接丢弃最大置信度较低的数据),本文提出了这样一个理论,尽管自适应伪标记可以降低尾类的阈值,但PL将选择更少的数据,因为来自尾类的数据很少会被分类为正确的类。为了增强学习能力并从这些类中发现未标记数据,本文作者利用其中的信息,而不是直接丢弃最大置信度较低的数据,即“无信息”的未标记数据。
在这里插入图片描述
将第一大置信度置为0:
在这里插入图片描述
判断是否为预测错误的伪标签,也就是是否要抱歉该预测的第一大置信度而去研究第二大置信度,若满足条件:最大置信度小于自适应阈值,以及我们的第二大置信对应的标签是尾类,我们将其视为错误的伪标签分类:
在这里插入图片描述

模型聚合
①权重计算:
在这里插入图片描述
②模型聚合:
在这里插入图片描述

1.3 Experiments

数据集:四个图像分类数据集( SVHN、CIFAR-10、CIFAR-100、Fashion MNIST)和一个医学图像分类数据集:ISIC 2018(针对黑色素瘤的皮肤病变分析);

客户端设置:训练数据集包含 10 个客户端:1 个已标记,9 个未标记。我们使用狄利克雷分布 Dir(γ),其中 5 个数据集的 γ = 0.8 在客户端中生成非独立同分布数据分区;

实验设置:使用 0.9 的 SGD 优化器,并使用 PyTorch 实现方法。对所有数据集采用 PyTorch 中的 ResNet18 。为了公平比较,在所有 FSSL 方法中使用相同的网络架构和训练协议,包括优化器、数据预处理等。对于标记的客户端,本地训练纪元设置为 11,对于未标记的客户端,设置为 1。

对比:我们将我们的方法与最先进的 FSSL 方法进行比较,包括 RSCFed、FedIRM 和 Fed-Consist。由于FedIRM和
Fed-Consist并非设计用于处理非IID FSSL,我们将标记客户端的权重扩大到约 50%,其他 9 个未标记客户端共享剩余的 50% 权重在每个 FSSL 同步轮中,实现这些方法时。我们还将我们的网络与使用所有客户端作为上限进行训练的 FedAvg以及仅使用一个标记客户端作为下限进行训练进行比较。此外,为了显示残差权重连接的有效性,我们报告了仅使用残差权重连接仅使用一个标记客户端训练的FedAvg 的结果。由于Fed-Consist 使用传统的基于批次的伪标记方法,因此我们使用我们提出的固定伪标记来报告Fed-Consist 的结果,而不增加标记客户端的权重。

使用所有客户端作为上限进行训练的 FedAvg以及仅使用一个标记客户端作为下限进行训练进行比较。此外,为了显示残差权重连接的有效性,仅使用残差权重连接仅使用一个标记客户端训练的FedAvg 的结果。Fed-Consist +是在Fed-Consist 的基础上使用我们提出的Fixed PL:
在这里插入图片描述

在一个带标签的客户端上进行训练,与不带残余权重连接的 FedAVG 相比,带残余权重连接的 FedAVG实现了更好的性能。
(a)显示了训练期间的测试精度曲线。由于标记客户端中训练数据分布不平衡,FedAVG 训练过程中测试精度不稳定(图中没有 res-weight)。但是,如果使用我们的剩余权重连接进行训练,测试精度曲线会更加稳定,并且性能也会得到增强(w/ res-weight和 w/ res-weight*)。W/ res-weight* 表示我们仅显示具有跳跃权重连接的通信轮次。

在这里插入图片描述
Vision Transformer(ViT)已被证明对异构数据和分布变化更加鲁棒,我们使用 ViT-Tiny 作为网络主干,在 SVHN 数据集上进行实验。补充材料中提供了实施细节。表2显示了比较结果。我们的方法可以胜过所有其他方法。类似地,在一个标记客户端上进行训练,带残差权重连接的 FedAVG 超过了不带残差权重连接的 FedAVG,这意味着我们的残差权重连接在 ViT 上也有效。在这里插入图片描述

将整个训练数据分为 10 个客户端,其中两个标记的客户端和八个未标记的客户端。
在这里插入图片描述

消融实验
在这里插入图片描述

1.4 conclusion

提出了用于联邦半监督学习(FSSL)的类平衡自适应伪标签(CBAFed); 在CBAFed中,提出了一种固定的伪标记策略;
为了处理FSSL的Non-IID设置,我们提出了一种类平衡自适应阈值选择方法来选择更好的伪标签。
设计了残差权连接方法,使模型更好地达到最优; 在五个数据集上评估 CBAFed,其性能显示了我们方法的优越性。

局限性:与其他 FSSL 方法一样,如果标记客户端中的数据数量极少,则无法保证预训练时的良好模型,因此最终的全局模型可能表现不佳。

  • 4
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
Gatys et al. (2016) proposed an algorithm for style transfer, which can generate an image that combines the content of one image and the style of another image. The algorithm is based on the neural style transfer technique, which uses a pre-trained convolutional neural network (CNN) to extract the content and style features from the input images. In this algorithm, the content and style features are extracted from the content and style images respectively using the VGG-19 network. The content features are extracted from the output of one of the convolutional layers in the network, while the style features are extracted from the correlations between the feature maps of different layers. The Gram matrix is used to measure these correlations. The optimization process involves minimizing a loss function that consists of three components: the content loss, the style loss, and the total variation loss. The content loss measures the difference between the content features of the generated image and the content image. The style loss measures the difference between the style features of the generated image and the style image. The total variation loss is used to smooth the image and reduce noise. The optimization is performed using gradient descent, where the gradient of the loss function with respect to the generated image is computed and used to update the image. The process is repeated until the loss function converges. The code for this algorithm is available online, and it is implemented using the TensorFlow library. It involves loading the pre-trained VGG-19 network, extracting the content and style features, computing the loss function, and optimizing the generated image using gradient descent. The code also includes various parameters that can be adjusted, such as the weight of the content and style loss, the number of iterations, and the learning rate.
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

qq_46738968

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值