Asymmetric Loss For Multi-Label Classification
Abstract
- 在典型的多标签设置中,一幅图片平均包含几个正标签和许多负标签。这种正负不平衡主导了优化过程,并可能导致在训练期间不重视来自正标签的梯度,从而导致不良的准确性。在本文中,我们引入了一种新的非对称损失(“ASL”),它对正负样本的作用不同。该损失使得能够动态地降低权重和硬阈值容易的负样本,同时也丢弃可能错误标记的样本。我们展示了 ASL 如何平衡不同样本的概率,以及这种平衡如何转化为更好的地图得分。通过ASL,我们在多个流行的多标签数据集上获得了最先进的结果:MS-COCO、Pascal-VOC、NUS-WIDE和Open Images。我们还演示了ASL对其他任务的适用性,如单标签分类和对象检测。ASL是有效的,易于实施,并且不会增加训练时间或复杂性。实施可从以下网址获得: GitHub - Alibaba-MIIL/ASL: Official Pytorch Implementation of: “Asymmetric Loss For Multi-Label Classification”(ICCV, 2021) paper。
- 论文地址:[2009.14119] Asymmetric Loss For Multi-Label Classification
- 论文指出多标签分类中存在正负样本不平衡,导致模型训练时忽略正样本的梯度,影响准确性。目标是设计一个损失函数来平衡这种不平衡,同时处理可能的误标样本。提出了两个关键机制:不对称聚焦和概率偏移,分别处理正负样本的不同衰减因子,并通过硬阈值丢弃简单负样本和误标样本。
- 首先回顾了二元交叉熵和焦点损失,然后引入不对称聚焦,将正负样本的聚焦参数分开,即 γ+ 和 γ-,通常设置 γ->γ+ 以减少负样本的影响。接着是概率偏移,通过 max (p-m, 0) 来硬阈值处理简单负样本,当 p<m 时损失为 0,同时通过梯度分析发现这种偏移还能拒绝误标样本。ASL 结合了这两个机制,形成统一的损失函数。
- 多标签分类被拆解为 K 个二元分类任务,每个标签的损失函数为:(
L
=
−
y
L
+
−
(
1
−
y
)
L
−
L = -yL_+ - (1-y)L_-
L=−yL+−(1−y)L−)。其中,(y=1) 为正样本,(y=0) 为负样本,(
L
+
L_+
L+) 和 (
L
−
L_-
L−) 分别为正负样本的损失项。
- 焦点损失(Focal Loss)对正负样本使用相同的聚焦参数 ( γ \gamma γ),但多标签中负样本数量远多于正样本,需更强的负样本衰减。引入独立的聚焦参数 γ + \gamma+ γ+(正样本)和 γ − \gamma- γ−(负样本),通常设 γ \gamma γ- > γ \gamma γ+(如 γ \gamma γ+=0,仅对负样本聚焦: L + = ( 1 − p ) γ + log p ; L − = p γ − log ( 1 − p ) L_+ = (1-p)^{\gamma_+} \log p ; \quad L_- = p^{\gamma_-} \log (1-p) L+=(1−p)γ+logp;L−=pγ−log(1−p)。正样本:当 p 接近 1 时,( ( 1 − p ) γ + (1-p)^{\gamma_+} (1−p)γ+) 衰减慢,保留高梯度(避免正样本被忽视)。负样本:当 p 接近 0 时,( p γ − p^{\gamma_-} pγ−) 衰减快,降低简单负样本的贡献。
- 简单负样本(如 ( p ≪ 0.5 p \ll 0.5 p≪0.5)对分类几乎无价值,直接丢弃以减少噪声。引入硬阈值 m,定义偏移概率 ( p m = max ( p − m , 0 ) p_m = \max(p - m, 0) pm=max(p−m,0)),负样本损失变为:( L − = ( p m ) γ − log ( 1 − p m ) L_- = (p_m)^{\gamma_-} \log (1 - p_m) L−=(pm)γ−log(1−pm))。当 (p < m) 时,(p_m = 0),损失为 0,直接忽略简单负样本。当 (p > m) 时,按聚焦参数衰减,处理困难负样本。
- 结合上述模块,ASL 定义为:( A S L = { ( 1 − p ) γ + log p 正样本 max ( p − m , 0 ) γ − log ( 1 − max ( p − m , 0 ) ) 负样本 ASL = \begin{cases} (1-p)^{\gamma_+} \log p & \text{正样本} \\ \max(p - m, 0)^{\gamma_-} \log (1 - \max(p - m, 0)) & \text{负样本} \end{cases} ASL={(1−p)γ+logpmax(p−m,0)γ−log(1−max(p−m,0))正样本负样本)。 γ + \gamma_+ γ+:控制正样本聚焦(通常设 0,即正样本使用标准交叉熵)。 γ − \gamma_- γ−:控制负样本衰减(增大可抑制简单负样本)。m:硬阈值,决定简单负样本的丢弃边界(实验中最优值约 0.05-0.4,依赖数据集)。
- 为了进一步提高多标签分类模型的性能和泛化能力,可以从数据处理、模型架构、训练策略、损失函数等多个方面进行优化,除了现有的
CutoutPIL
和RandAugment
,还可以尝试其他数据增强方法,如 MixUp、CutMix。MixUp 通过将不同样本按一定比例混合,增加样本的多样性;CutMix 则是将不同图像的部分区域进行裁剪和拼接。可以将多种数据增强方法组合使用,例如先进行RandAugment
随机增强,再应用 MixUp 混合样本,这样可以让模型学习到更丰富的特征表示。- 特别是多标签分类中容易出现的漏标或误标问题。可以采用人工审核、多轮标注等方式提高标注的准确性。去除数据集中质量较差的样本,如模糊、损坏或与标签不匹配的图像,避免这些噪声数据对模型训练产生负面影响。
- 对于类别分布不均衡的情况,可以采用过采样少数类样本或欠采样多数类样本的方法。例如,使用 SMOTE(Synthetic Minority Over - sampling Technique)算法生成少数类的合成样本,或者随机欠采样多数类样本。在数据加载器中使用加权采样策略,使得模型在训练过程中更频繁地看到少数类样本,从而提高对少数类的识别能力。
- 在多标签分类任务中,正负样本的区分基于标签本身。对于一个样本(例如一张图像),如果某个标签的真实值为 1,则该样本在这个标签上属于正样本;如果真实值为 0,则属于负样本。以图像分类为例,假设一张图像有 “猫”“狗”“鸟” 三个标签,若图像中存在猫,则 “猫” 这个标签对应的样本为正样本;若不存在猫,则为负样本。在 ASL/src/loss_functions/losses.py 中的 AsymmetricLoss 类的 forward 方法里,通过目标标签 y 来区分正负样本.
- 难易样本的区分主要依据样本的预测概率。对于正样本而言,预测概率接近 1 的样本通常是简单样本,因为模型能够很有把握地识别出这些样本;而预测概率接近 0 的正样本则是困难样本,说明模型在识别这些样本时存在困难。对于负样本,预测概率接近 0 的是简单样本,预测概率接近 1 的是困难样本。
- 在 ASL 损失中,通过聚焦参数 gamma_pos 和 gamma_neg 以及概率偏移 clip 来处理难易样本。较大的 gamma_neg 会使简单负样本的损失权重变小,从而聚焦于困难负样本;而 clip 参数则用于裁剪负样本的概率,避免简单负样本对损失的影响过大。
Introduction
-
典型的自然图像包含多个对象和概念 ,突出了多标签分类对于现实世界任务的重要性。最近,在MS-COCO 、NUS-WIDE 、Pascal-VOC 和Open Images 等多标签基准测试中取得了显著的进步。据报道,通过图形神经网络利用标签相关性取得了显著的成功,图形神经网络表示标签关系 或基于知识先验的单词嵌入 。其他方法基于对图像部分和注意力区域建模 ,以及使用递归神经网络 。
-
典型的自然图像包含多个对象和概念,突出了多标签分类对于现实世界任务的重要性。最近,在MS-COCO 、NUS-WIDE 、Pascal-VOC 和Open Images 等多标签基准测试中取得了显著的进步。据报道,通过图形神经网络利用标签相关性取得了显著的成功,图形神经网络表示标签关系 或基于知识先验的单词嵌入。其他方法基于对图像部分和注意力区域建模,以及使用递归神经网络。
-
多标签分类的一个关键特征是当标签的总数很大时产生的固有正负不平衡。大多数图像只包含一小部分可能的标签,这意味着平均起来,每个类别的阳性样本数将远远低于阴性样本数。为了解决这个问题,提出了静态处理多标签问题中不平衡的损失函数。然而,它是专门针对长尾分布情况的。在密集对象检测中也遇到了高度不平衡,其中它源于前景与背景区域的比率。提出了一些基于重采样方法的解决方案,通过仅选择可能背景示例的子集。然而,重采样方法不适合处理多标签分类不平衡,因为每个图像包含许多标签,并且重采样不能仅改变特定标签的分布。
-
对象检测中的另一个常见解决方案是采用焦点损失,它随着标签置信度的增加而衰减损失。这将重点放在硬样本上,而降低简单样本的权重,简单样本主要与简单背景位置相关。令人惊讶的是,焦损失很少用于多标签分类,交叉熵通常是默认选择。由于在多标签分类中也会遇到高的负-正不平衡,焦点损失可能会提供更好的结果,因为它鼓励关注相关的硬负样本,这些样本大多与不包含正类别但包含一些其他混淆类别的图像相关。然而,对于多标签分类的情况,如焦点损失所提出的,同等对待阳性和阴性样本是次优的,因为它导致来自阴性样本的更多损失梯度的累积,以及来自罕见阳性样本的重要贡献的权重降低。换句话说,网络可能关注从负面样本中学习特征,而不强调从正面样本中学习特征。
-
在本文中,我们介绍了一种用于多标签分类的非对称损失(ASL ),它明确地解决了正负不平衡问题。ASL基于两个关键属性:首先,为了在保持正样本的贡献的同时关注硬负样本,我们将正样本和负样本的调制解耦,并为它们分配不同的指数衰减因子。第二,我们建议转移负样本的概率,以完全丢弃非常容易的负样本(硬阈值)。通过制定损失导数,我们证明了概率转移也能够丢弃非常硬的负样本,怀疑是误标记的,这在多标记问题中是常见的。
-
我们将ASL与常见的对称损失函数、交叉熵和焦点损失进行了比较,并使用我们的不对称公式显示了显著的地图改进。通过分析模型的概率,我们证明了ASL在平衡正负样本方面的有效性。我们还引入了一种方法,通过要求正负平均概率之间的固定间隔,在整个训练过程中动态调整不对称水平,从而简化超参数选择过程。本文的贡献可概括如下:
-
我们设计了一个新的损失函数,ASL,它明确地处理了多标签分类中的两个主要挑战:高度的正负不平衡,以及基本事实错误标记。
-
我们通过详细的梯度分析彻底研究了损耗特性。引入了用于控制损耗的不对称水平的自适应过程,以简化超参数选择的过程。
-
使用ASL,我们在四个流行的多标签基准上获得了最先进的结果。例如,我们在MS-COCO数据集上达到86.6%的mAP,超过了之前的最高结果2.8%。
-
我们的解决方案有效且易于使用。与最近的方法相比,它基于标准架构,不增加训练和推理时间,并且不需要任何外部信息。为了让 ASL 变得容易理解,我们分享我们训练过的模型和完全可复制的训练代码。
-
-
图1: (a)多标签分类的现实挑战。典型的图像包含很少的阳性样本和许多的阴性样本,导致很高的正负不平衡。此外,GT 中的缺失标注在多标注数据集中也很常见。ASL提出的解决方案。损耗特性将在第2.5节中详述
-
-
多标签分类的核心挑战是正负样本严重不平衡:每张图像平均包含少量正标签(如 2-3 个)和大量负标签(如 80-2.9=77.1 个,以 MS-COCO 为例),导致训练时负样本梯度主导,正样本特征学习被弱化。此外,多标签数据常存在负样本误标问题(如漏标或错误标注),进一步干扰模型优化。打破传统对称损失(如交叉熵、焦点损失)对正负样本的同等处理,通过差异化机制分别优化正负样本的损失权重。
- 简单负样本:通过硬阈值直接丢弃(概率偏移),避免其占用训练资源。
- 困难负样本:通过软阈值(聚焦参数)降低权重,聚焦真正有价值的难例。
- 正样本:保持高梯度贡献,确保其特征被充分学习。
-
不对称聚焦模块通过不同的指数因子调整正负样本的损失权重,概率偏移模块则通过硬阈值丢弃简单负样本,两者结合既处理了不平衡,又处理了误标问题。论文在多个数据集上验证了 ASL 的有效性,包括 MS-COCO、Pascal-VOC 等,对比了焦点损失和交叉熵,显示 ASL 在 mAP 等指标上的提升。超参数分析部分,讨论了 γ+、γ- 和 m 的影响,发现动态调整 γ- 可以简化超参数选择,并且通过概率分析证明 ASL 能平衡正负样本的平均概率。ASL 体现了对问题本质的分解,将不平衡问题拆解为简单负样本和误标样本的处理,通过不对称设计打破对称损失的局限性,强调动态调整和针对性处理。通过概率分析和自适应调整,使模型在训练中自动平衡正负样本的置信度,避免陷入局部最优。
Asymmetric Loss
-
在本节中,我们将首先回顾交叉熵和焦点损失。然后,我们将介绍提出的非对称损失(ASL)的组成部分,旨在解决多标签数据集固有的不平衡性质。我们还将分析ASL梯度,提供概率分析,并提出一种方法来动态设置训练期间的损失不对称水平。
-
import torch import torch.nn as nn class AsymmetricLoss(nn.Module): def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=True): super(AsymmetricLoss, self).__init__() # 负样本的聚焦参数,用于控制对负样本的聚焦程度 self.gamma_neg = gamma_neg # 正样本的聚焦参数,用于控制对正样本的聚焦程度 self.gamma_pos = gamma_pos # 不对称剪裁参数,用于限制负样本概率的上限 self.clip = clip # 是否禁用 PyTorch 自动梯度计算聚焦损失,以提高性能 self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss # 防止对数运算中出现零值的小常数 self.eps = eps def forward(self, x, y): """" Parameters ---------- x: input logits y: targets (multi-label binarized vector) """ # Calculating Probabilities # 对输入的 logits 应用 sigmoid 函数,得到正样本的概率 x_sigmoid = torch.sigmoid(x) xs_pos = x_sigmoid # 计算负样本的概率 xs_neg = 1 - x_sigmoid # Asymmetric Clipping # 如果 clip 参数不为 None 且大于 0,则对负样本的概率进行剪裁 if self.clip is not None and self.clip > 0: # 防止模型对负样本过于自信,将负样本概率上限限制为 1 xs_neg = (xs_neg + self.clip).clamp(max=1) # Basic CE calculation # 计算正样本的交叉熵损失 los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) # 计算负样本的交叉熵损失 los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) # 将正样本和负样本的交叉熵损失相加 loss = los_pos + los_neg # Asymmetric Focusing # 如果 gamma_neg 或 gamma_pos 大于 0,则进行不对称聚焦 if self.gamma_neg > 0 or self.gamma_pos > 0: if self.disable_torch_grad_focal_loss: # 禁用 PyTorch 自动梯度计算,提高性能 torch.set_grad_enabled(False) # 计算正样本的概率(仅考虑目标为正的样本) pt0 = xs_pos * y # 计算负样本的概率(仅考虑目标为负的样本) pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p # 合并正样本和负样本的概率 pt = pt0 + pt1 # 计算不对称聚焦的 gamma 值 one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) # 计算不对称聚焦的权重 one_sided_w = torch.pow(1 - pt, one_sided_gamma) if self.disable_torch_grad_focal_loss: # 恢复 PyTorch 自动梯度计算 torch.set_grad_enabled(True) # 将损失乘以不对称聚焦的权重 loss *= one_sided_w # 对损失求和并取负 return -loss.sum()
-
交叉熵损失用于衡量模型预测概率分布与真实标签分布之间的差异。在多标签分类中,分别计算正样本和负样本的交叉熵损失,然后将它们相加。通过剪裁负样本的概率,防止模型对负样本过于自信,从而避免模型在训练过程中过度关注负样本。通过调整
gamma_neg
和gamma_pos
参数,对正样本和负样本进行不同程度的聚焦。较大的gamma_neg
会使模型更加关注难分类的负样本,而较小的gamma_pos
则可以减少对简单正样本的关注。
Binary Cross-Entropy and Focal Loss
-
正如通常在多标签分类中所做的那样,我们将问题简化为一系列的二元分类任务。给定K个标签,基网络为每个标签 zk 输出一个logit。每个 logit 由sigmoid函数σ(zk)独立激活。让我们将yk表示为K类的基本事实。总分类损失 Ltot 通过聚合K个标签的二进制损失获得:
-
L t o t = ∑ k = 1 K L ( σ ( z k ) , y k ) . ( 1 ) L_{tot} =\sum ^K _{k=1} L (σ(z_k), y_k). (1) Ltot=k=1∑KL(σ(zk),yk).(1)
-
每个标签的二进制损失 L 的一般形式由下式给出:
-
L = − y L + − ( 1 − y ) L − ( 2 ) L = −yL_+ − (1 − y)L_− (2) L=−yL+−(1−y)L−(2)
-
其中,y 是 GT 标签(为简洁起见,我们省略了类别指数k),L+和L分别是正负损耗部分。通过将L+和L设置为:
-
{ L + = ( 1 − p ) γ log ( p ) L − = p γ log ( 1 − p ) (3) \left\{\begin{aligned} {{L_{+}}} & {{} {{}=( 1-p )^{\gamma} \operatorname{l o g} ( p )}} \\ {{L_{-}}} & {{} {{}=p^{\gamma} \operatorname{l o g} ( 1-p )}} \\ \end{aligned} \right. \tag{3} {L+L−=(1−p)γlog(p)=pγlog(1−p)(3)
-
其中 p = σ(z) 是网络的输出概率,γ 是聚焦参数。γ = 0产生二进制交叉熵。
-
-
通过在等式3中设置γ > 0。容易否定的贡献(具有低概率,p ≪ 0.5)可以在损失函数中被降低权重,使得能够在训练期间更多地关注更难的样本。
Asymmetric Focusing
-
当使用焦点损失进行多标签训练时,有一个内在的权衡:设置高γ,以充分降低来自容易阴性的贡献,可以消除来自罕见阳性样本的梯度。我们建议分离正样本和负样本的聚焦级别。设 γ+ 和 γ- 分别为正聚焦参数和负聚焦参数。我们通过重新定义损失来获得不对称聚焦:
-
{ L + = ( 1 − p ) γ + log ( p ) L − = p γ − log ( 1 − p ) (4) \left\{\begin{aligned} {{L_{+}}} & {{} {{} {}=( 1-p )^{\gamma_{+}} \operatorname{l o g} ( p )}} \\ {{L_{-}}} & {{} {{}=p^{\gamma_{-}} \operatorname{l o g} ( 1-p )}} \\ \end{aligned} \right. \tag{4} {L+L−=(1−p)γ+log(p)=pγ−log(1−p)(4)
-
-
由于我们对强调正样本的贡献感兴趣,我们通常设置γ->γ+。不对称聚焦消除了正负样品的衰变率。通过这样做,我们可以更好地控制正样本和负样本对损失函数的贡献,并帮助网络从正样本中学习有意义的特征,尽管它们很少。
-
需要注意的是,通过静态加权因子解决类别不平衡的方法已在之前的工作中提出。然而,发现这些加权因子与聚焦参数相互作用,使得有必要一起选择两者。在实践中,甚至提出了一个有利于背景样本的加权因子(α = 0.25)。在第3节中,我们将说明简单的线性加权不足以正确处理多标签分类中的正负不平衡问题。出于这些原因,我们选择避免在我们的聚焦公式中加入静态加权因子。
Asymmetric Probability Shifting
-
当负样本的概率较低时,不对称聚焦减少了负样本对损失的贡献(软阈值)。因为多标签分类中的不平衡水平可能非常高,所以这种衰减并不总是足够的。因此,我们提出了一种额外的非对称机制,概率移位,它对非常容易的负样本执行硬阈值处理,即,当负样本的概率非常低时,它完全丢弃负样本。让我们将转移概率pm定义为:
-
p m = m a x ( p − m , 0 ) ( 5 ) p_m = max(p − m, 0) (5) pm=max(p−m,0)(5)
-
其中概率裕度m ≥ 0是可调超参数。将pm集成到等式3的 L 中。我们得到一个非对称的概率转移的焦损失:
-
L − = ( p m ) γ l o g ( 1 − p m ) , ( 6 ) L_− = (p_m)^ γ log(1 − p_m) ,(6) L−=(pm)γlog(1−pm),(6)
-
-
在图2中,我们绘制了负样本的概率移位焦损失,并将其与常规焦损失和交叉熵进行比较。从几何角度来看,我们可以看到,概率移位相当于将损失函数向右移动一个因子m,因此当p < m时,L = 0。我们稍后将通过梯度分析展示概率移位机制的另一个重要属性,它也可以剔除误标记的负样本。
-
-
图2:损失比较。对于阴性样品,将概率转移焦损失与常规焦损失和交叉熵进行比较。我们使用γ= 2,m = 0.2。
-
-
请注意,概率转移的概念并不局限于交叉熵或焦点损失,它可以用于许多损失函数。例如,线性铰链损耗 也可以看作线性损耗的(对称)概率偏移。另请注意,由于非线性sigmoid运算,逻辑移位不同于概率移位。
ASL Definition
-
为了定义不对称损耗(ASL ),我们将不对称聚焦和概率转移这两种机制整合到一个统一的公式中:
-
A S L = { L + = ( 1 − p ) γ + log ( p ) L − = ( p m ) γ − log ( 1 − p m ) (7) A S L=\left\{\begin{aligned} {} & {{} L_{+}=} & {{} ( 1-p )^{\gamma_{+}} \operatorname{l o g} ( p )} \\ {} & {{} L_{-}=} & {{} ( p_{m} )^{\gamma_{-}} \operatorname{l o g} ( 1-p_{m} )} \\ \end{aligned} \right. \tag{7} ASL={L+=L−=(1−p)γ+log(p)(pm)γ−log(1−pm)(7)
-
-
其中pm在等式5中定义。ASL允许我们应用两种类型的不对称来减少易阴性样本对损失函数的贡献——通过聚焦参数γ->γ+的软阈值和通过概率裕度m的硬阈值。可以方便地设置γ+ = 0,这样正样本将导致简单的交叉熵损失,并通过单个超参数γ控制不对称聚焦的水平。为了实验和推广,我们仍然保持γ+自由度。
Gradient Analysis
-
为了更好地理解ASL的特性和行为,我们接下来提供了损失梯度的分析,与交叉熵和焦点损失的梯度相比较。查看梯度很有用,因为在实践中,网络权重是根据损耗梯度相对于输入logit z进行更新的。ASL中负样本的损耗梯度为:
-
d L − d z = ∂ L − ∂ p ∂ p ∂ z = ( p m ) γ − [ 1 1 − p m − γ − log ( 1 − p m ) p m ] p ( 1 − p ) \begin{aligned} {{\frac{d L_{-}} {d z}}} & {{} {{} {} {}=\frac{\partial L_{-}} {\partial p} \frac{\partial p} {\partial z}}} \\ {{}} & {{} {{} {} {}=( p_{m} )^{\gamma-} \Big[ \frac{1} {1-p_{m}}-\frac{\gamma_{-} \operatorname{l o g} \big( 1-p_{m} \big)} {p_{m}} \Big] p ( 1-p )}} \\ \end{aligned} dzdL−=∂p∂L−∂z∂p=(pm)γ−[1−pm1−pmγ−log(1−pm)]p(1−p)
-
-
其中 p = 1 1 + e − z p = \frac1 {1+e^{-z}} p=1+e−z1,pm在等式5中定义。在图3中,我们展示了ASL的归一化梯度,并将其与其他损耗进行了比较。根据图3,我们可以将ASL中的负样本大致分为三个损失区域:
-
硬阈值 p < m时非常容易出现负值,应忽略不计,以便关注更硬的样本。
-
p > m的软阈值负样本,当它们的概率较低时应该被衰减。
-
误标记-非常硬的负样本,p > p*,其中 p* 被定义为 d d p ( d L d z ) = 0 \frac d {dp}(\frac {dL} {dz}) = 0 dpd(dzdL)=0 的点,被怀疑为误标记-当网络计算出负样本的概率非常大时,该样本有可能被误标记,其正确的标记应该是正的。[Learning a deep convnet for multi-label classification with partial labels]已经表明,多标记数据集易于错误标记阴性样本,这可能是因为人工标记任务很困难。当处理高度不平衡的数据集时,即使很小的负样本误标记率也会在很大程度上影响训练。因此,剔除贴错标签的样品是有益的。必须小心地进行拒绝,以允许网络传播来自实际错误分类的负样本的梯度。
-
-
图3:梯度分析。比较不同损失状态下的损失梯度与概率。CE =交叉熵(m =γ= 0),CE+PS =概率移位的交叉熵(m > 0,γ= 0),AF =不对称聚焦(m = 0,γ> 0),ASL (m > 0,γ> 0)。
-
ASL 将负样本分为三类:(p < m):硬阈值丢弃(梯度为 0)。 m < p < p ∗ m < p < p^* m<p<p∗:软衰减,聚焦困难负样本。 p > p ∗ p > p^* p>p∗:梯度反向,拒绝误标样本。
-
-
在表1中,我们根据梯度分析将ASL的属性和能力与其他损失进行了比较。我们可以看到,只有当我们将聚焦和概率裕度这两种不对称机制相结合时,我们才享有对不平衡数据集有益的所有能力和优点:非常容易的样本的硬阈值化、容易的样本的非线性衰减、错误标记的样本的拒绝以及连续的损失梯度。
-
-
表1:不同损失的性质- CE(交叉熵),AF(不对称聚焦),PS(概率偏移)。
-
Probability Analysis
-
在本节中,我们希望为我们的主张提供进一步的支持,即在多标签数据集中,使用对称损失(如交叉熵或焦点损失)对于学习正样本的特征来说是次优的。我们通过监控训练期间网络输出的平均概率来做到这一点。这使我们能够评估网络对阳性和阴性样本的置信度。低置信度表明特征没有被正确学习。我们首先将pt定义为:
-
p t = { { p ˉ i f y = 1 1 − p ˉ o t h e r w i s e (9) p_{t}=\left\{\begin{cases} {{\bar{p}}} & {{\mathrm{i f ~} y=1}} \\ {{1-\bar{p}}} & {{\mathrm{o t h e r w i s e}}} \\ \end{cases} \right. \tag{9} pt={{pˉ1−pˉif y=1otherwise(9)
-
-
其中 p ˉ \bar p pˉ 表示每次迭代时一批样本的平均概率。分别用 p t + p ^+ _t pt+ 和 p t − p ^- _t pt− 表示正负样本的平均概率,用 Δ p \Delta p Δp 表示概率差:
-
Δ p = p t + − p t − \Delta p = p^+_t-p^-_t Δp=pt+−pt−
-
均衡的训练应证明阳性和阴性样本的平均置信水平相似,即 Δ p \Delta p Δp 在整个训练过程中和训练结束时应较小。
-
-
在图4中,我们给出了三种不同损失函数(交叉熵、焦点损失和ASL)在训练过程中的平均概率 p t + p ^+ _t pt+ 和 p t − p^-_t pt−。图4展示了对不平衡数据集使用对称损失的局限性。当使用交叉熵损失或焦点损失进行训练时,我们观察到 p t − ≫ p t + p^-_t ≫p^+_t pt−≫pt+ (在训练结束时, Δ p \Delta p Δp 分别为0.23和0.1)。这意味着优化过程给了阴性样本太多的权重。相反,当用 ACL 训练时,我们可以消除差距,这意味着网络有能力适当地强调积极的样本。
-
-
图4:概率分析。在MS-COCO上使用交叉熵、焦点损失和ASL沿着训练的正样本和负样本的平均概率。对于焦点损失,我们使用γ = 2。对于ASL,我们使用γ+ = 0,γ= 2,m = 0.2。
-
-
请注意,通过在推断时降低决策阈值pth( 如果p > pth,则样本将被声明为阳性),我们可以控制精确度与召回率的权衡,并且支持高真阳性率而不是低假阴性率。然而,通过对称损失获得的大的负概率差距表明,网络不重视来自正样本的梯度,并且收敛到局部最小值,具有次优性能。我们将在第3节中验证这一说法。
Adaptive Asymmetry
-
损失函数的超参数通常通过手动调谐过程来调整。这个过程通常很繁琐,并且需要一定水平的专业知识。基于我们的概率分析,我们希望提供一种简单直观的方法,用一个可解释的控制参数来动态调整ASL的不对称水平。在上一节中,我们展示了ASL能够平衡网络,防止负样本的pt明显大于正样本(p < 0)。我们现在希望反过来,在整个训练过程中动态调整γ,以匹配所需的概率差距,用ptarget表示。我们可以通过在每批之后对γ进行简单调整来实现这一点,如等式 11 所示。
-
γ − ← γ − + λ ( ∆ p − ∆ p t a r g e t ) ( 11 ) γ− ← γ− + λ(∆p − ∆ptarget) (11) γ−←γ−+λ(∆p−∆ptarget)(11)
-
其中λ是专用步长。随着ptarget,Eq 11的增加。使我们能够在整个训练过程中动态增加不对称水平,迫使优化过程更多地关注正样本的梯度。请注意,使用类似于Eq 11的逻辑。我们还可以动态调整概率余量,或者同时调整两种不对称机制。为简单起见,我们选择在整个训练过程中只调整γ的情况,γ+ = 0,固定概率裕量很小。
-
-
附录A中的图9显示了ptarget = 0.1时,整个训练过程中的γ和p值。在10%的训练之后,网络成功地收敛到目标概率间隙,并且收敛到γ的稳定值。在下一节中,我们将分析这个动态方案的地图得分和可能的用例。
Experimental Study
-
在本节中,我们将提供全面的实验,以更好地理解不同的损耗,并展示与其他损耗相比,我们从ASL获得的改善。我们还将测试我们的自适应不对称机制,并将其与固定方案进行比较。对于测试,我们将使用众所周知的MS-COCO 数据集(完整的数据集和训练细节见第4.1.1节)。焦损失与交叉熵:在图5中,我们展示了不同焦损失γ值(γ = 0为交叉熵)下获得的图谱得分。我们可以从图5中看到,在交叉熵损失的情况下,mAP得分低于在焦点损失的情况下获得的得分(84.0%对85.1%)。对于2 ≤ γ ≤ 4,获得了具有焦点损失的最高分数。当γ低于该范围时,损失不会为简单的负样品提供足够的向下加权。当γ高于该范围时,稀有阳性样品的重量下降太多。
-
-
图5: mAP与焦损γ的关系。比较不同焦损失γ值的MS-COCO图谱得分。
-
-
不对称聚焦:在图6中,我们测试了不对称聚焦机制:对于γ的两个固定值2和4,我们沿着γ+轴给出了图分数。图6展示了不对称聚焦的有效性——当我们降低γ+(从而增加不对称程度)时,mAP 得分显著提高。
-
-
图6: mAP Vs .不对称聚焦γ+。比较不同不对称聚焦γ+值(γ= 2和γ= 4)的MS-COCO图谱得分。
-
-
有趣的是,在我们的实验中,简单地设置 γ+ = 0会得到最好的结果。这可以进一步支持对于阳性样本保持高梯度幅度的重要性。事实上,允许γ+ > 0对于存在大量易阳性样本的情况可能是有用的。请注意,我们还尝试用 γ+ < 0进行训练,以进一步扩展不对称性。然而,这些试验并不一致,因此它们没有出现在图6中。不对称概率裕度:在图7中,我们在交叉熵损失(γ = 0)和两级(对称)焦点损失(γ = 2和γ = 4)之上应用第二种不对称机制,不对称概率裕度。
-
-
图7: mAP与不对称概率余量。比较不同不对称概率裕度值的MS-COCO图得分,在对称焦点损失之上,γ = 0,2,4
-
-
我们可以从图7中看到,对于交叉熵和焦点损失,引入非对称概率裕度提高了 mAP 得分。对于交叉熵,最佳概率裕度较低,m = 0.05,与我们的梯度分析一致——具有概率裕度的交叉熵产生非平滑损失梯度,容易样本的衰减较少。因此,小概率裕度更好,它仍然能够对非常容易的样本进行硬阈值处理,并拒绝错误标记的样本。对于焦点损失,最佳概率裕度明显更高,0.3 ≤ m ≤ 0.4。这也可以通过分析损耗梯度来解释:由于焦点损耗已经具有简单样本的非线性衰减,我们需要更大的概率裕度来引入有意义的不对称。我们还可以看到,当引入非对称概率裕量时,γ = 2比γ = 4获得更好的分数,这意味着非对称概率裕量在适度的焦点损失之上工作得更好。
-
比较不同的不对称性:到目前为止,我们分别测试了每种ASL不对称性。在表2中,我们展示了当组合不对称时获得的图谱分数,并将它们与单独应用每种不对称时获得的分数进行比较。此外,我们将ASL的结果与另一种非对称机制进行了比较——如[Imbalance problems in object detection: A review]中提出的聚焦损失与线性加权相结合,这种机制在静态上有利于阳性样本。在0.5到0.95之间的值范围内搜索最佳静态权重,跳过0.05。
-
-
表2:不同不对称方法的MS-COCO图谱得分。γ+ = 0,γ-= 3时获得的聚焦图。γ = 2,m = 0.3时获得的裕度图。γ+ = 0,γ-= 4,m = 0.05时获得的组合图。
-
-
从表2中我们可以看出,当组合不对称的两个分量时,获得了最好的结果。这与我们在图3中对损耗梯度的分析相关,在图3中,我们展示了如何将两种不对称性结合起来,从而能够丢弃非常简单的样本、简单样本的非线性衰减以及拒绝可能错误标记的非常硬的负样本,这是在仅应用一种不对称时不可能得到的结果。表2还显示,使用静态称重不足以正确处理多标签分类中的高负-正不平衡,ASL对简单和困难样品进行动态操作,表现更好。自适应不对称:我们现在通过等式11中提出的程序来检查动态调整ASL不对称水平的有效性。在表3中,我们给出了在不同 Δ p t a r g e t \Delta p_{target} Δptarget 值下获得的mAP得分和 γ 的最终值。
-
从表3中我们可以看出,即使没有任何调谐,要求无偏情况ptarget = 0,与焦点损耗相比也有显著改善(85.8%对85.1%)。当使用更高的概率间隙,ptarget = 0.2时,可以获得更好的分数。有趣的是,额外关注罕见的阳性样本( Δ p t a r g e t > 0 \Delta p_{target} > 0 Δptarget>0 )比仅仅要求无偏的情况要好。
-
-
表3:适应性不对称。对于不同的p目标,从自适应不对称运行中获得的mAP得分和γ。
-
-
请注意,与固定γ的最佳ASL分数相比,从动态方案获得的最高mAP分数仍然低0.2%。这种(小)退化的一个可能原因是训练过程受到第一个时期的高度影响。在训练开始时,动态调整超参数可能是次优的,这会降低整体性能。为了补偿初始恢复迭代,动态调谐的γ趋向于收敛到更高的值,但是总得分仍然有些受阻。由于这种下降,我们在第4节中选择使用固定不对称方案。
-
尽管如此,动态方案可能对非专业用户有吸引力,因为它允许通过一个简单的可解释的超参数来控制不对称水平。此外,我们将在未来探索将该方案扩展到其他应用的方法,例如按类自适应调整γ,这对于常规的穷举搜索来说是不切实际的。
Dataset Results
- 在本节中,我们将在四个流行的多标签分类数据集上评估ASL,并将其结果与已知的最先进技术以及其他常用的损失函数进行比较。我们还将测试ASL对其他计算机视觉任务的适用性,如单标签分类和对象检测。
Multi-Label Datasets
MS-COCO
-
MS-COCO 是一种广泛用于评估计算机视觉任务(如对象检测、语义分割和图像字幕)的数据集,最近已被用于评估多标签图像分类。对于多标签分类,它包含80个不同类别的122,218幅图像,其中每幅图像平均包含2.9个标签,因此给出平均正负比: 2.9 80 − 2.9 = 0.0376 \frac {2.9} {80- 2.9} = 0.0376 80−2.92.9=0.0376。该数据集分为82,081幅图像的训练集和40,137幅图像的验证集。根据 MS-COCO 的常规设置,我们报告了以下统计数据:平均平均精度(mAP)、平均每类精度(CP)、召回率(CR)、F1 (CF1)和平均总体精度(OP)、召回率(OR) 和 F1 (OF1),以及总体统计数据和前3名的最高得分。在这些指标中,mAP、OF1和CF1是主要指标,因为它们考虑了假阴性和假阳性率。
-
在表4中,我们将ASL结果与文献中已知的最先进方法进行了比较,作为主要指标(附录B中提供了完整的训练细节和损耗超参数)。在附录C的表8中,我们列出了所有指标的结果。从表4中我们可以看出,使用ASL,我们在ResNet101(多标签分类中常用的架构)上明显优于以前的最先进方法,并将top mAP得分提高了1%以上。其他指标也显示有所改善。
-
-
表4:在MS-COCO上ASL与最先进方法的比较。所有指标都以%为单位。报告输入分辨率448的结果。
-
-
请注意,我们基于ASL的解决方案不需要修改架构,也不会增加推理和训练时间。这与之前的顶级解决方案形成对比,后者包括复杂的架构修改(注意区域,GCNs ),注入标签嵌入等外部数据,以及使用教师模型。然而,ASL是这些方法的完全补充,使用它们也可以导致分数的进一步提高,代价是增加训练的复杂性和减少吞吐量。此外,从表4中我们可以看出,使用TResNet-L等旨在匹配ResNet101 的GPU吞吐量的新架构,我们可以进一步提高mAP得分,同时仍然保持相同的训练和推理时间。这是我们提出的解决方案的另一个贡献——确定现代快速架构可以极大地促进多标签分类,并且ResNet101的常见使用可能不是最佳的。在图8中,我们测试了ASL对不同主干网的适用性,比较了三种常用架构上的不同损失函数:OFA-595 、ResNet101 和 TResNet-L。从图8中我们可以看到,在所有主干网上,ASL都优于焦点损失和交叉熵,证明了其对主干网选择的鲁棒性,以及相对于之前损失函数的优越性。
-
-
图8:测试不同主干上的不同损耗。
-
-
预训练和输入分辨率的影响:在表5中,我们比较了标准ImageNet-1K预训练和较新的ImageNet-21K预训练获得的mAP结果。我们可以看到,使用更好的预训练对结果产生了巨大的影响,将mAP得分提高了近2%。我们还在表5中显示,将输入分辨率从448提高到640可以进一步改善结果。
-
-
表5:不同ImageNet预训练方案和输入分辨率的MS-COCO mAP得分比较。所有指标都以%为单位。
-
Pascal-VOC
-
Pascal视觉对象类挑战(VOC 2007) 是另一个用于多标签识别的流行数据集。它包含来自20个对象类别的图像,平均每个图像2.5个类别。Pascal-VOC分为一个由5011幅图像组成的 trainval 集和一个由 4952 幅图像组成的测试集。我们的训练设置与用于MS-COCO的设置相同。请注意,以前关于PascalVOC的大多数工作都使用简单的ImageNet预训练,但也有一些使用额外的数据,如在MS-COCO上的预训练或使用BERT等NLP模型。为了公平的比较,我们用ImageNet预训练和附加的预训练数据(MS-COCO)分别呈现我们的结果,并将它们与相关作品进行比较。结果见表6。
-
-
表6:在Pascal-VOC数据集上将ASL与已知的最新模型进行比较。度量单位是%。
-
-
从表6中我们可以看到,无论有没有额外的预培训,ASL都在Pascal-VOC上取得了新的最先进的成绩。在附录的表9中,我们比较了Pascal-VOC上的不同损失函数,表明ASL优于交叉熵和焦点损失。
NUS-WIDE
- 在附录E中,我们带来了另一个常见的多标记数据集,NUS-WIDE 的结果。表10显示,ASL再次大幅度优于以前的顶级方法,并在全国学生联盟范围内达到新的最先进的结果。
Open Images
- Open Images (v6) 是一个大规模数据集,由900万幅训练图像和125,436幅测试图像组成。它部分标注了人工标签和机器生成的标签。开放图像的规模远远大于以前的多标签数据集,如NUS-WIDE、Pascal-VOC和MS-COCO。此外,它还包含大量未标注的标签。这使得我们可以在极端分类和高度错误标注的情况下测试ASL。完整的数据集和训练细节见附录f。据我们所知,还没有公开图像版本6的其他方法的结果。因此,我们仅将ASL与多标签分类中的其他常见损失函数进行比较。然而,我们希望我们的结果可以作为未来比较的基准。打开图像的结果显示在表7中。从表7中我们可以看出,在 Open Images 上,ASL明显优于焦损失和交叉熵,这表明ASL适用于大型数据集和极端分类情况。
-
-
表7:在Open Images V6数据集上ASL与焦损失和交叉熵的比较。
-
Additional Computer Vision Tasks
- 除了多标签分类,我们还想在其他相关的计算机视觉任务上测试 ASL。由于细粒度单标签分类和目标检测任务通常包含大部分背景或长尾情况 ,并且已知使用焦点损失会受益,因此我们选择在这些任务上测试ASL。在附录的G节和H节中,我们显示,对于这些额外的任务,ASL在相关数据集上优于焦点损失,这表明ASL不仅限于多标签分类。
Conclusion
- 本文提出了一种用于多标签分类的非对称损失(ASL)。ASL包含两个互补的不对称机制,它们对阳性和阴性样本的作用不同。通过检查ASL衍生工具,我们对损失属性有了更深的理解。通过网络概率分析,我们证明了ASL在平衡正负样本方面的有效性,并提出了一种自适应方案,可以在整个训练过程中动态调整不对称水平。大量的实验分析表明,在包括MS-COCO、Pascal-VOC、NUSWIDE和Open Images在内的流行的多标签分类基准上,ASL优于常见的损失函数和先前的最新方法。
- ASL 通过不对称聚焦和概率偏移两大机制,直击多标签分类的核心痛点 —— 正负不平衡与负样本误标,以简单高效的方式实现性能突破。其哲学在于针对性分解问题:将复杂的不平衡问题拆解为简单样本丢弃、困难样本衰减、正样本保护三个子问题,通过数学建模和实验验证层层递进。
A. Adaptive Asymmetry dynamics
-
-
图9:自适应不对称动态。当ptarget = 0.1时,整个训练过程中的γ和p值。γ+设为0,m设为0.05。
B. Multi-Label General Training Details
- 除非另有明确说明,否则我们使用以下训练程序:我们使用Adam optimizer和单周期策略训练模型60个时期,最大学习速率为2e-4。对于正则化,我们使用标准的增强技术。我们发现,常见的ImageNet统计归一化并没有改善结果,而是使用了一种更简单的归一化方法——将所有RGB通道缩放到0到1之间。根据第3节中的实验,对于ASL,我们使用γ-= 4、γ+ = 0和m = 0.05,对于焦损失,我们使用 γ = 2。我们为多标签训练默认和推荐的主干是TResNet-L。然而,为了与以前的作品进行公平的比较,我们还在一些数据集上添加了 ResNet101 主干结果 (TResNet-L和ResNet101在运行时是等效的)。
C. Comparing MS-COCO On All Common Metrics
- 在表8中,我们在MS-COCO数据集的所有通用指标上将ASL结果与已知的最先进方法进行了比较。
-
-
表8:在MS-COCO数据集上将ASL与已知的最新模型进行比较。所有指标都以%为单位。报告输入分辨率448的结果。
-
D. Comparing Loss Function on Pascal-VOC Dataset
- 在表9中,我们将ASL结果与Pascal-VOC数据集上的其他损失函数进行了比较。
-
-
表9:Pascal-VOC数据集上ASL与其他损失函数的比较。度量单位是%。
-
E. NUS-WIDE
- NUS-WIDE 数据集最初包含来自Flicker的 269,648 幅图像,这些图像已经用81个视觉概念进行了手动注释。由于一些网址已经被删除,我们只能下载220,000张图片。我们可以在以前的工作中找到新加坡国立大学范围数据集的其他变体,很难进行一对一的比较。我们建议在未来的工作中使用我们公开的variant【https://drive.google.com/file/d/0B7IzDz-4yH_HMFdiSE44R1lselE/view】进行标准化和完全公平的比较。我们使用标准的70-30训练测试分割。我们的训练设置与用于MS-COCO的设置相同。从表10中我们可以看出,ASL大幅度提高了 NUS-WIDE 已知的最先进的结果。在表11中,我们将ASL结果与NUS范围数据集上的其他损失函数进行了比较,再次表明ASL优于交叉熵和焦点损失。
-
-
表10:ACL 与 NUS-WIDE 数据集上的已知最新模型的比较。所有指标都以%为单位。
-
-
表11:在 NUS-WIDE 数据集上ASL与其他已知损失函数的比较。所有指标都以%为单位。
-
F. Open Images Training Details
- 由于flicker上缺少链接,我们只能从开放图像数据集中下载114,648张测试图像,其中包含大约5,400个唯一的标记类。为了处理开放图像数据集的部分标记方法,我们将所有未标记的标签设置为负的,具有降低的权重。由于图像数量巨大,我们在224的输入分辨率下训练了30个时期的网络,并在448的输入分辨率下对其进行了5个时期的微调。由于正负失衡的水平明显高于MS-COCO,我们增加了损失不对称的水平:对于ASL,我们用γ-= 7,γ+ = 0进行训练。对于焦点损失,我们用γ = 4进行训练。其他训练细节与用于MS-COCO的类似。
G. Fine-Grain Single-Label Classification Results
- 为了在细粒度单标签分类上测试ASL,我们选择了竞争植物标本馆2020 FGVC7挑战赛。植物标本馆2020的目标是从纽约植物园(NYBG)提供的大型长尾植物标本馆标本中识别维管植物物种。该数据集包含超过100万张图像,代表超过32,000种植物。这是一个有长尾的数据集;每个物种至少有3个标本,然而,一些物种有超过100个标本。比赛选择的衡量标准是宏观F1分数。对于焦点损失,我们用γ = 2进行训练。对于ASL,我们用γ-= 4,γ+ = 0进行训练。比赛选择的衡量标准是宏观F1分数。在表12中,我们提供了植物标本数据集上的ASL结果,并将其与常规焦损失进行比较。从表12中我们可以看到,在这个细粒度的单标签分类数据集上,ASL远远优于focal loss。注意,植物标本馆2020是CVPR-卡格尔分类竞赛。我们的ASL测试集分数将在153个团队中获得第三名。
-
-
表12:植物标本数据集上ASL与病灶损失的比较。宏观F1是比赛的官方指标。所有的结果都在一个看不见的 private-set.上。
-
H. Object Detection Results
-
为了在物体检测上测试ASL,我们使用了MSCOCO 数据集(物体检测任务),它包含118k图像的训练集和5k图像的评估集。为了训练,我们使用了流行的mm-detection 包,以及ATSS 和FCOS 中讨论的增强功能作为目标检测方法。我们用 SGD optimizer 对TResNet-M 模型进行了70个时期的训练,动量为0.9,重量衰减为0.0001,批量为48。我们使用学习率预热,初始学习率为0.01,在时期40、60减少10倍。对于ASL,我们使用 γ+ = 1,γ= 2。对于焦点损失,我们使用通用值γ = 2 。请注意,与多标签和细粒度单标签分类数据集不同,对于对象检测,γ+ = 0不是最佳解决方案。其原因可能是需要平衡在对象检测中使用的3个损失(分类、边界框和中心)的贡献。我们应该在将来进一步调查这个问题。我们的对象检测方法,FCOS ,使用3种不同类型的损失:分类(焦点损失),包围盒(IoU损失)和中心(普通交叉熵)。受大量背景样本影响的唯一成分是分类损失。因此,为了测试,我们仅用ASL替换了分类焦点损失。在表13中,我们比较了从ASL训练中获得的mAP分数和从标准焦点损失中获得的分数。从表13中我们可以看出,ASL得分高于常规焦损失,mAP得分提高了0.4%。
-
-
表13:在MS-COCO检测数据集上ASL与焦点损失的比较。
-
-
损失函数部分,src/loss_functions/losses.py 中有 AsymmetricLoss 类,这是核心,
AsymmetricLoss
(多标签)、ASLSingleLabel
(单标签)。论文中提到的不对称损失,针对正负样本不平衡问题,通过调整 gamma 参数和剪裁操作来处理。需要分析这个损失函数的实现,比如 forward 方法中如何计算正负样本的概率,应用剪裁和聚焦权重。使用 TResnet 作为 backbone,src/models/tresnet/tresnet.py 中定义了不同规模的 TResnet 模型,如 TResnetM、TResnetL 等。模型结构包括卷积块、SE 模块,以及下采样层,这些可能与论文中提到的模型架构相关。辅助函数中的 ModelEma 用于指数移动平均,提升模型性能,这也是训练中的技巧。-
ASL/ ├── src/ │ ├── loss_functions/ # 损失函数核心实现 │ │ └── losses.py # AsymmetricLoss/ASLSingleLabel │ ├── models/ # 模型架构 │ │ ├── tresnet/ # TResNet主干 │ │ └── utils/ # 模型工厂、工具函数 │ └── helper_functions/ # 辅助功能(如ModelEma、数据加载) ├── train.py # 训练主脚本(支持COCO/NUS-WIDE等数据集) ├── validate.py # 验证脚本(计算mAP) ├── MODEL_ZOO.md # 预训练模型列表(含SOTA结果) └── requirements.txt # 依赖环境(PyTorch、randaugment等)
-
-
正负样本差异化处理:
gamma_pos
(正样本聚焦参数)和gamma_neg
(负样本聚焦参数):对正样本采用低聚焦(或不聚焦),对负样本加强聚焦以抑制简单负样本。不对称剪裁(Asymmetric Clipping):通过clip
参数对负样本概率进行上界限制(如 0.05),避免模型过度自信于负样本。 -
class AsymmetricLossOptimized(nn.Module): ''' Notice - optimized version, minimizes memory allocation and gpu uploading, favors inplace operations''' def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): super(AsymmetricLossOptimized, self).__init__() self.gamma_neg = gamma_neg self.gamma_pos = gamma_pos self.clip = clip self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss self.eps = eps # prevent memory allocation and gpu uploading every iteration, and encourages inplace operations self.targets = self.anti_targets = self.xs_pos = self.xs_neg = self.asymmetric_w = self.loss = None def forward(self, x, y): """" Parameters ---------- x: input logits y: targets (multi-label binarized vector) """ self.targets = y self.anti_targets = 1 - y # Calculating Probabilities self.xs_pos = torch.sigmoid(x) self.xs_neg = 1.0 - self.xs_pos # Asymmetric Clipping if self.clip is not None and self.clip > 0: # 使用 inplace 操作,减少内存分配 self.xs_neg.add_(self.clip).clamp_(max=1) # Basic CE calculation self.loss = self.targets * torch.log(self.xs_pos.clamp(min=self.eps)) # 使用 inplace 操作,减少内存分配 self.loss.add_(self.anti_targets * torch.log(self.xs_neg.clamp(min=self.eps))) # Asymmetric Focusing if self.gamma_neg > 0 or self.gamma_pos > 0: if self.disable_torch_grad_focal_loss: torch.set_grad_enabled(False) self.xs_pos = self.xs_pos * self.targets self.xs_neg = self.xs_neg * self.anti_targets self.asymmetric_w = torch.pow(1 - self.xs_pos - self.xs_neg, self.gamma_pos * self.targets + self.gamma_neg * self.anti_targets) if self.disable_torch_grad_focal_loss: torch.set_grad_enabled(True) # 使用 inplace 操作,减少内存分配 self.loss *= self.asymmetric_w return -self.loss.sum()
-
使用
add_
和clamp_
等 inplace 操作,减少内存分配和 GPU 上传,提高训练效率。提前定义成员变量,避免在每次前向传播时进行内存分配。 -
class ASLSingleLabel(nn.Module): ''' This loss is intended for single-label classification problems ''' def __init__(self, gamma_pos=0, gamma_neg=4, eps: float = 0.1, reduction='mean'): super(ASLSingleLabel, self).__init__() # 正样本的聚焦参数 self.gamma_pos = gamma_pos # 负样本的聚焦参数 self.gamma_neg = gamma_neg # 标签平滑参数 self.eps = eps # 对输入的 logits 应用 logsoftmax 函数 self.logsoftmax = nn.LogSoftmax(dim=-1) # 存储目标类别的向量 self.targets_classes = [] # 损失的缩减方式,可选 'mean' 或 'sum' self.reduction = reduction def forward(self, inputs, target): # 获取类别数量 num_classes = inputs.size()[-1] # 对输入的 logits 应用 logsoftmax 函数 log_preds = self.logsoftmax(inputs) # 将目标标签转换为 one-hot 编码 self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) # ASL weights targets = self.targets_classes anti_targets = 1 - targets # 计算正样本的概率 xs_pos = torch.exp(log_preds) # 计算负样本的概率 xs_neg = 1 - xs_pos # 仅考虑目标为正的样本的正样本概率 xs_pos = xs_pos * targets # 仅考虑目标为负的样本的负样本概率 xs_neg = xs_neg * anti_targets # 计算不对称聚焦的权重 asymmetric_w = torch.pow(1 - xs_pos - xs_neg, self.gamma_pos * targets + self.gamma_neg * anti_targets) # 将 log_preds 乘以不对称聚焦的权重 log_preds = log_preds * asymmetric_w if self.eps > 0: # label smoothing # 进行标签平滑,防止模型过拟合 self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes) # loss calculation # 计算损失 loss = - self.targets_classes.mul(log_preds) # 对损失在类别维度上求和 loss = loss.sum(dim=-1) if self.reduction == 'mean': # 如果 reduction 为 'mean',则对损失求平均 loss = loss.mean() return loss
-
LogSoftmax:将输入的 logits 转换为对数概率,便于后续计算交叉熵损失。将目标标签转换为 one-hot 编码,方便计算每个类别的损失。与多标签分类中的不对称聚焦类似,通过调整
gamma_neg
和gamma_pos
参数,对正样本和负样本进行不同程度的聚焦。通过eps
参数对目标标签进行平滑处理,防止模型过拟合。 -
ASL/train.py
中整体训练逻辑包括模型的初始化、数据加载、优化器和学习率调度器的设置、训练循环、验证过程以及模型保存。使用create_model(args)
函数创建模型,并将其移动到 GPU 上。如果指定了model_path
,则加载预训练的 ImageNet 模型参数,但排除head.fc
层的参数,以适应多标签分类任务。使用CutoutPIL
和RandAugment
对训练数据进行增强,提高模型的泛化能力。使用torch.optim.Adam
作为优化器,结合add_weight_decay
函数对模型参数进行权重衰减。使用lr_scheduler.OneCycleLR
调度器,在训练过程中动态调整学习率。使用torch.cuda.amp.GradScaler
和autocast
实现混合精度训练,减少内存使用和训练时间。使用ModelEma
类对模型参数进行指数移动平均,提高模型的稳定性和泛化能力。CutoutPIL(cutout_factor=0.5)
:在图像上随机裁剪出一个矩形区域,并填充随机颜色。RandAugment()
:对图像进行随机数据增强。