长尾分布论文笔记:BBN

BBN: Bilateral-Branch Network with Cumulative Learning for Long-Tailed Visual Recognition

一、背景介绍

1.长尾效应

在这里插入图片描述

长尾分布比较常见,指的是数据集中少量类别占总数据集比重较大。如果使用带有长尾分布的数据集去直接训练分类网络的话,就会导致对于占比较大的类别能够较好的预测,占比较小的类别不能够较好的预测。这样做,模型整体性能就会下降。

2.问题

常见的解决长尾效应的方法有resampling(采样)和cost-sensitive re-weighting(给损失函数添加权重)——这两个统称class re-balancing方法。

  • resampling:over-sampling——重复采集少类数据,可能会导致对少数类的过拟合。under-sampling——放弃主要类别数据,削弱深度网络泛化能力。
  • re-weighting:通常在损失函数中为尾类的训练样本分配较大的权重。

通过采用class rebalancing方法可以更改训练集逼近测试集的分布并且使得训练更加关注尾部类别,这也就是为什么class rebalancing方法可以改善长尾效应。

然而,尽管重新平衡方法具有良好的最终预测,但我们认为这些方法仍然有不利影响,即它们也会在一定程度上意外地损害学习到的深层特征(即表示学习)的表示能力。具体来说,重采样存在尾部数据过拟合(通过过采样)的风险,并且在数据不平衡极度时也存在欠拟合整个数据分布(通过欠采样)的风险。对于重新加权,它会通过直接改变甚至反转数据呈现频率来扭曲原始分布。

**[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ohk5Dczf-1677743599359)(C:\Users\21713\AppData\Roaming\Typora\typora-user-images\image-20230224215124630.png)]**

作者通过实验论证了上述观点。具体做法为首先任意选取一个分类网络,将其分为两部分:特征提取部分+分类层,接着使用三种方法训练该网络,三种方法分别为正常训练、resampling、re-weighting。这样我们便得到了分别用三种方法训练得到的分类网络,选取其特征提取部分并固定住,然后从头开始训练整个分类网络(此时特征提取部分参数不再变化,就像迁移学习中的冻结网络)。然后继续使用三种方法训练得到上图的实验结果。从实验结果不难看出当时用resampling和re-weighting使得最终模型的分类性能提高但是特征提取能力变低。(感觉这里度量特征提取部分性能的方法有点怪怪的,作者采取的方式是如果特征提取层提取到的特征性能较好,那么使用相同的训练分类层方法训练得到的整个模型性能就越好,类似于特征≈数据。故横向对比时,使用CE方法训练得到的模型性能最好)。

从这个实验结果图来看,可以得出一个结论支撑作者的模型,也就是当我们使用CE方法训练整个模型,然后使用RS/RW方法从头开始训练分类层,此时我们可以得到最好的结果。从实验结果来看,这个训练方法是明显好于直接使用RW/RS来训练整个网络的性能。

具体来说,对于每个类,我们首先通过平均此类的表示来计算质心向量。然后,计算这些表示与其质心之间的 2 距离,然后将其平均作为类内表示的紧凑性的度量。如果一个类的平均距离很小,则意味着该类的表示在特征空间中聚集得很近。我们在训练阶段将表示的 2-范数归一化为 1,以避免特征尺度的影响。我们根据分别使用交叉熵 (CE)、重新加权 (RW) 和重新采样 (RS) 学习的表征来报告结果。

基于此,我们便可以知道,想要得到一个更好的分类模型去处理长尾问题,我们便需要充分利用上述这个实验结果。

除了这个,作者还在补充材料里面利用另一种方法论证了使用re-balancing方法会导致特征提取能力变差。这次使用的评价标准是类内特征之间的距离,如果类内特征之间的距离越近,即平均距离越近,得到的特征越好。

基于上图,我们便可以得出下图这个结论。
在这里插入图片描述

综上,在本文中,作者揭示了re-balancing的机制是显着促进分类器学习,但会在一定程度上意外地损害所学深层特征的表示能力。

二、How class re-balancing strategies work?

这个部分上面两个实验已经介绍过了。即re-balancing可以提高分类器的性能但是同时会退化特征提取器性能。

三、方法介绍

1.网络结构图

在这里插入图片描述

作者已经发现使用re-balancing方法可以提高模型性能,但是使用该方法会导致特征提取层模型性能下降。故作者想要结合这两个方法的优势,来进一步提高模型性能。作者的办法是使用一种累计学习策略,先学习通用模式,然后逐渐关注尾部数据。

这里简单介绍一下这个网络的流程。首先通过两个部分共享的双分支网络,输入一个是具有长尾分布的数据集 ( x c , y c ) (x_c,y_c) (xc,yc),另一个是通过reverse操作后的数据集 ( x r , y r ) (x_r,y_r) (xr,yr)。特征提取网络采用的是残差网络,最后一个残差网络不共享权重。GAP指的是全局平均池化。最终两个层的输出特征分别为 f c f_c fc f r f_r fr。中间的 α \alpha α是一个参数,其随着epoch的增加而改变。最终的输出结果为 z z z,其表达式如下
z = α W c T f c + ( 1 − α ) W r T f r z=\alpha W_c^Tf_c+(1-\alpha)W_r^Tf_r z=αWcTfc+(1α)WrTfr
由于是分类,最终的输出要经过softmax操作得到概率值 p p p

α \alpha α的大小变化来看,随着训练的增加模型参数更新越来越依赖于红色框的分支。

损失函数的表达式如下:
L = α E ( p , y c ) + ( 1 − α ) E ( p , y r ) L=\alpha E(p,y_c)+(1-\alpha)E(p,y_r) L=αE(p,yc)+(1α)E(p,yr)
这里简单说一下,第二个分支输入数据、权重共享策略以及如何获得以及 α \alpha α的更新策略。

  • 第二个分支主要通过每个类别的概率 P i P_i Pi来进行采集, P i P_i Pi表达式如下。

P i = w i ∑ j = 1 C w j , w i = N m a x N i P_i=\frac{w_i}{\sum_{j=1}^{C}w_j},w_i=\frac{N_{max}}{N_i} Pi=j=1Cwjwi,wi=NiNmax

​ 这里 N m a x N_{max} Nmax表示所有类别中类别数量最大的样本数, N i N_i Ni表示第 i i i类样本对应的样本数。

数据生成步骤如下:1.计算出 P i P_i Pi;2.根据 P i P_i Pi随机抽取一个类 i i i;3.均匀地从第 i i i类中抽取一个样本进行替换。重复这个过程,直到获得一个batch的样本。

  • 在 BBN 中,两个分支在经济上共享相同的残差网络结构,如图 3 所示。我们使用 ResNets [12] 作为我们的骨干网络,例如 ResNet-32 和 ResNet-50。详细地说,除了最后一个残差块之外,两个分支网络共享相同的权重。共享权重有两个好处:一方面,传统学习分支(蓝色框)的良好学习表示可以有利于重新平衡分支(红色框)的学习。另一方面,共享权重将大大降低推理阶段的计算复杂度。

  • α = 1 − ( T T m a x ) 2 \alpha=1-(\frac{T}{T_{max}})^2 α=1(TmaxT)2

    这里的 T T T表示当前时刻的epoch数, T m a x T_{max} Tmax表示最大的epoch数。

2.推理阶段

这里面涉及到一个超参 α \alpha α,这里作者直接令 α = 0.5 \alpha=0.5 α=0.5

四、实验结果

1.数据集介绍

CIFAR数据集为常用实验数据集,作者用 β = N m a x N m i n \beta=\frac{N_{max}}{N_{min}} β=NminNmax来表示不平衡比例。iNaturalist数据集是真实世界的大尺度数据集,其数据类别分布极度不平衡且存在细粒度问题。

2.对比实验

在这里插入图片描述

在这里插入图片描述

这里简单说一下2X scheduler的意思是允许两倍个epoch数。

3.消融实验

在这里插入图片描述

4.观点的验证实验

在这里插入图片描述
在这里插入图片描述

注意,这个图是越平坦+方差越小越好。纵坐标表示对于分类器第 i i i类的倾向性。

疑惑点

这个 α \alpha α起作用原因是啥——个人觉得这东西很像混合学习,只不过取的混合系数分布不同。

  • 2
    点赞
  • 9
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
R语言贝叶斯网络(BBN)是一种用于建模和分析数据的统计工具。BBN是基于贝叶斯概率理论的一种图模型,在计算机科学和统计学领域中得到广泛应用。 R语言是一种专门用于数据分析和统计建模的开源编程语言,它提供了丰富的统计分析和数据可视化函数。在R语言中,我们可以使用各种包(package)来实现贝叶斯网络的建模和分析。 贝叶斯网络是一种用于表示和推断变量之间条件依赖关系的图模型。它通过有向无环图(Directed Acyclic Graph,DAG)描述变量之间的因果关系。图中的节点表示变量,边表示变量之间的条件依赖关系。 在R语言中,我们可以使用各种包来构建和分析BBN。其中,比较常用的包有bnlearn、deal、gRain等。这些包提供了丰富的函数和算法,可以帮助我们构建BBN、学习模型参数和进行推断。 在构建BBN时,我们需要先定义变量和它们之间的条件依赖关系。然后,可以使用不同的算法学习模型的参数。学习的过程可以基于数据集来进行,也可以通过专家知识来进行。学习完模型参数后,我们就可以使用该模型进行推断和预测。 通过BBN,我们可以进行许多统计分析任务,比如概率推断、预测、特征选择等。BBN还可以用于决策分析,帮助我们进行最优决策的制定。 总之,R语言贝叶斯网络(BBN)是一种用于建模和分析数据的强大工具。它可以帮助我们揭示变量之间的条件依赖关系,并用于各种统计分析和决策分析任务中。
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值