paper总结(8)FlexMatch: Boosting Semi-Supervised Learning with Curriculum Pseudo Labeling

问题背景

最近提出的FixMatch在大多数半监督学习(SSL)基准测试上取得了最先进的结果。然而,与其他现代SSL算法一样,FixMatch使用预定义的所有类的常数阈值来选择有助于训练的未标记数据,没有考虑到不同类的不同学习状态和学习困难。

FixMatch和其他流行的SSL算法(如伪标记和无监督数据增强(UDA))的缺点是,它们依赖固定的阈值来计算无监督损失,只使用预测置信度高于阈值的无标记数据。虽然该策略可以确保只有高质量的无标记数据有助于模型训练,但它忽略了相当多的其他无标记数据,特别是在训练过程的早期阶段,只有少数无标记数据的预测置信度高于阈值。此外,现代SSL算法平等地处理所有类,而不考虑它们不同的学习困难。

针对这些问题,作者提出了课程伪标记(Curriculum Pseudo Labeling, CPL)策略,这是一种考虑到每个班级学习状况的半监督学习的课程学习策略。CPL用灵活的阈值代替预先定义的阈值,根据当前的学习状态动态调整每个类的阈值。值得注意的是,这个过程不引入任何额外的参数(超参数或可训练参数)或额外的计算(向前或向后传播)。将这种课程学习策略直接应用于FixMatch,并将改进算法称为FlexMatch。

贡献:

提出了课程伪标记(CPL),这是一种动态利用未标记数据进行SSL的课程学习方法。它几乎是免费的,并且可以很容易地集成到其他SSL方法。

CPL在通用基准测试上显著提高了几种流行SSL算法的准确性和收敛性能。特别是FlexMatch,即FixMatch和CPL的集成,实现了最先进的结果。

我们开源了TorchSSL,这是一个统一的基于pytorch的半监督学习代码库,用于公平地研究SSL算法。TorchSSL包括流行的SSL算法及其相应的训练策略的实现,并且易于使用或定制。

方法:FlexMatch

Curriculum Pseudo Labeling(课程伪标签)

根据学习状态动态确定阈值并非易事。最理想的方法是计算每个类的评估精度,并使用它们来缩放阈值,如下:

其中Tt(c)为时间步t下c类的灵活阈值,at(c)为相应的评价精度。这样,较低的准确率表明该类的学习状态不太令人满意,这将导致较低的阈值,鼓励该类的更多样本被学习。由于不能在模型学习过程中使用评估集,因此可能必须从训练集中分离出一个额外的验证集来进行准确性评估。然而,这种做法显示了两个致命的问题:首先,在SSL场景下,这种与训练集分离的标记验证集是昂贵的,因为标记数据已经稀缺。其次,为了在训练过程中动态调整阈值,必须在每个时间步t上连续进行精度评估,这将大大降低训练速度。

CPL使用了另一种方法来估计学习状态,它不引入额外的推断过程,也不需要额外的验证集。一个高阈值可以过滤掉有噪声的伪标签,只留下高质量的标签,可以大大降低确认偏差。因此,关键假设是,当阈值较高时,一个类的学习效果可以通过预测落入该类且高于阈值的样本数量来反映。即预测置信度达到阈值的样本越少的类,学习难度越大或学习状态越差,表示为:

其中,σt(c)反映了c类在时间步t时的学习效果。pm,t(y|un)为模型对时间步t时的无标记数据un的预测,N为无标记数据总数。当未标记数据集达到平衡(即属于不同类别的未标记数据数量相等或接近)时,σt(c)越大,表示估计的学习效果越好。通过对σt(c)应用以下归一化,使其范围在0到1之间,然后可以使用它缩放固定阈值τ:

这种归一化方法的一个特点是,学习最好的类的βt(c)等于1,导致其灵活的阈值等于τ。这是可取的。对于难以学习的类,降低阈值,鼓励在这些类中学习更多的训练样本。这也提高了数据利用率。随着学习的进行,学习型班级的门槛会提高,从而选择性地挑选出质量更高的样本。最终,当所有类别都达到可靠的精度时,阈值将全部接近τ。请注意,阈值并不总是增长,如果未标记的数据在以后的迭代中被分类到不同的类中,阈值也可能会降低。这个新的阈值用于计算FlexMatch中的无监督损失,它可以表述为:

最后,可以将FlexMatch中的损失表述为有监督损失和无监督损失的加权组合(λ):

Threshold warm-up(阈值热身)

在实验中注意到,在训练的早期,模型可能会根据参数初始化的不同,盲目地将大部分未标记的样本预测到某一类(即。,更有可能存在确认偏误)。因此,在这个阶段,估计的学习状态可能不可靠。因此,引入一个热身过程,将公式中的分母改写为:

其中N−∑C C =1 σt(C)可视为未使用的未标记数据的数量。这确保了在训练开始时,所有估计的学习效果都从0逐渐上升,直到未使用的未标记数据的数量不再占主导地位。这个周期的持续时间取决于数据集的未标记数据量和学习难度。实际上,这样的预热过程非常容易实现,我们可以添加一个额外的类来表示未使用的未标记数据。因此,计算的分母简单地转换为在c + 1个类中寻找最大值。

Non-linear mapping function(非线性映射函数)

柔性阈值是通过线性映射归一化估计的学习效果确定的。但在真实训练过程中,它可能不是最合适的映射,βt(c)的增减可能在早期出现较大的跳跃,模型的预测仍然不稳定;而只做小波动后的班是在中后期训练阶段学好的。因此,当βt(c)较大时,弹性阈值越敏感越好,反之亦然。

我们提出一个非线性映射函数,使阈值在βt(c)均匀从0到1范围内具有非线性增长曲线,公式如下:

其中M(·)为非线性映射函数。显然,通过设置M为恒等函数,可以将式(7)视为一种特例。映射函数M应该是单调递增的,并且最大值不大于1/τ(否则弹性阈值可以大于1并过滤掉所有样本)。为了避免引入额外的超参数(例如,灵活阈值的下限),考虑映射函数的范围从0到1,以便灵活阈值的范围从0到τ。

单调递增凸函数使阈值在βt(c)较小时增长缓慢,随着βt(c)的增大而变得更加敏感。因此,我们直观地选择了具有上述性质的凸函数

进行实验。

实验

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
boosting-crowd-counting-via-multifaceted-attention是一种通过多方面注意力提升人群计数的方法。该方法利用了多个方面的特征来准确估计人群数量。 在传统的人群计数方法中,往往只关注人群的整体特征,而忽略了不同区域的细节。然而,不同区域之间的人群密度可能存在差异,因此细致地分析这些区域是非常重要的。 该方法首先利用卷积神经网络(CNN)提取图像的特征。然后,通过引入多个注意力机制,分别关注图像的局部细节、稀疏区域和密集区域。 首先,该方法引入了局部注意力机制,通过对图像的局部区域进行加权来捕捉人群的局部特征。这使得网络能够更好地适应不同区域的密度变化。 其次,该方法采用了稀疏区域注意力机制,它能够识别图像中的稀疏区域并将更多的注意力放在这些区域上。这是因为稀疏区域往往是需要重点关注的区域,因为它们可能包含有人群密度的极端变化。 最后,该方法还引入了密集区域注意力机制,通过提取图像中人群密集的区域,并将更多的注意力放在这些区域上来准确估计人群数量。 综上所述,boosting-crowd-counting-via-multifaceted-attention是一种通过引入多个注意力机制来提高人群计数的方法。它能够从不同方面细致地分析图像,并利用局部、稀疏和密集区域的特征来准确估计人群数量。这个方法通过考虑人群分布的细节,提供了更精确的人群计数结果。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值