解决医学图像分类中,类别增量学习的不平衡问题
  • 提出背景
  • 解法拆解
  • 子解法1: CIL平衡分类损失
  • 子解法2: 分布边缘损失
  • 类增量学习算法流程



 


提出背景

论文: Addressing Imbalance for Class Incremental Learning in Medical Image Classificatio

类增量学习是指模型在保持对既有类别识别能力的同时,逐步引入新的类别进行学习的过程。

这在医学领域尤为重要,因为医学数据的分布经常会因为新病种的出现、治疗手段的更新或患者数据的变化而发生变化。

传统的深度学习方法往往假设数据分布是固定不变的,但这在现实中很少成立。

我们需要在原有的模型基础上加入新的类别,而不是从头开始训练一个全新的模型。

在不采取特别措施的情况下,当模型开始学习新的疾病类别时,可能会出现“灾难性遗忘”,即模型在学习新知识的同时,遗忘了原有的知识(如旧知识的识别能力下降)。

解决医学图像分类中,类别增量学习的不平衡问题_数据

  1. 步骤1:在这一步中,模型1(Model 1)被训练来识别三个类别:无糖尿病视网膜病变(No DR)、轻度(Mild)和中度(Moderate)的糖尿病视网膜病变。
    模型训练完成后,能够对这三种状态进行预测。
  2. 步骤2:在这一步中,新的类别——严重(Severe)的糖尿病视网膜病变被引入。
    现有的模型(现在称为模型2)在保持原有功能的基础上,增加了对新类别的识别能力。注意,在这一步中,只提供严重病变的训练数据。
  3. 步骤3:在这一步中,又引入了另一个新类别——增殖性(Proliferative)的糖尿病视网膜病变。
    模型3在前两个模型的基础上继续训练,以包括这个新类别的识别。

在每个步骤中,模型都需要能够识别所有至今为止见过的类别。

这展示了类增量学习的挑战:每次增加新的类别时,模型不仅要学习新的信息,同时还要保持对旧信息的记忆,避免灾难性遗忘。

 

因此,作者提出了一种新的方法来应对数据分布的变化,特别是新旧类别之间的数据不平衡问题,这种不平衡是导致“灾难性遗忘”(即模型忘记旧知识)的一个主要原因。

为了解决这个问题,作者提出了两种新的损失函数:

  1. CIL平衡分类损失:这个方法通过调整损失函数中的logits(即模型输出的原始预测值),特别强调罕见类别,以避免模型偏向于新类别或样本较多的类别。
  2. 分布边缘损失:这个方法通过优化模型的特征空间,减少不同类别间的重叠,并实现类内的紧密聚集,以改善类别间的界限。

通过这些方法,模型不仅能更好地处理新旧数据之间的不平衡,还能有效防止在学习新类别时对旧类别的遗忘。

 

解法拆解

在医学图像分类中实现类增量学习,即让模型能够逐步学习新的数据类别,同时保持对旧数据类别的识别能力。

问题

  1. 灾难性遗忘:在学习新类别的过程中,模型容易遗忘已学习的旧类别。
  2. 类别间的重叠:新旧类别在特征空间中易于发生重叠,导致分类效果下降。

解法 = CIL-平衡分类损失(因为类别不平衡) + 分布边缘损失(因为特征空间重叠)

子解法1: CIL平衡分类损失
  • 特征: 类别不平衡导致模型对频繁出现的类别过度敏感。
  • 之所以使用CIL平衡分类损失,是因为需要调整模型对不同频率类别的关注度,特别是增强对少数类别的学习效果。
  • 例如,如果某疾病在训练数据中出现得较少,这种损失函数能够通过增加该类别的权重,帮助模型更好地学习和记住这种较少见的疾病。
子解法2: 分布边缘损失
  • 特征: 类别在特征空间的分布容易重叠。
  • 之所以使用分布边缘损失,是因为需要在特征空间中清晰地分隔旧类和新类,避免它们之间的混淆。
  • 例如,在处理不同阶段的糖尿病视网膜病变时,这种损失能够确保模型在特征层面上区分开轻度和重度病变,避免将重度病变误判为轻度。

解决医学图像分类中,类别增量学习的不平衡问题_特征空间_02

  • 场景设置:初始模型已经能够识别无病变、轻度和中度病变,现在需要引入严重和增殖性病变这两个新类别。
  • 图(a)
    传统的边缘损失方法,它强调了锚点嵌入(ha)与正类嵌入(wp)之间的相似性要大于锚点与负类嵌入(wn)之间的相似性,但没有考虑类别分布的分离。
    在原有的基础上,传统的边缘损失方法帮助模型区分轻度(正类嵌入)和中度(负类嵌入)病变,但未特别处理新加入的严重和增殖性病变类别,可能导致新类别与旧类别混淆。
  • 图(b)
    引入了分布边缘损失,目标是将ha从负类分布推远,而不仅仅是wn,从而减少特征空间的重叠。
    通过分布边缘损失,调整模型以将严重和增殖性病变从中度病变的特征分布中显著分离出来,避免两者在特征空间中的重叠。
  • 图©:说明传统边缘损失未能充分减小类内距离,可能导致ha偏离其真实类别中心。
  • 图(d)
    展示了分布边缘损失确保ha保持在其对应类别分布内,增强了类内紧凑性。
    进一步确保严重和增殖性病变的特征嵌入不仅与中度病变区分开,而且紧密围绕其真实特征分布,提高了对这些阶段的诊断精度和模型的类内紧凑性。

 

CIL平衡分类损失通过调整模型对不同类别的权重来解决类别不平衡问题,增强对少数类的识别。

分布边缘损失通过增加类别间的特征空间距离,减少新旧类别的重叠,以提高分类的准确性和鲁棒性。

假设我们的目标是开发一个能够随着时间逐步识别和分类糖尿病视网膜病变从无到轻微、中度、严重以及增殖性的模型。

场景描述:

  • 初始模型:训练一个模型来识别三个基本的糖尿病视网膜病变阶段:无病变(No DR)、轻度(Mild)、和中度(Moderate)。
  • 数据扩展:随后需要增加模型对更严重病变阶段的识别能力,包括严重(Severe)和增殖性(Proliferative)病变。

子解法1: CIL平衡分类损失的应用:

  • 问题:在初始的训练集中,无病变(No DR)的图像可能远多于轻度和中度病变的图像,这可能导致模型对常见类别过度敏感,而忽略较少见的类别。
  • 应用:通过CIL平衡分类损失,增加轻度和中度病变的权重,在模型训练过程中强化对这些类别的学习,从而提高模型对这些较少见病变阶段的识别能力。
  • 效果:在后续的类增量学习步骤中,模型能够更准确地识别轻度和中度病变,避免由于类别不平衡造成的诊断错误。

子解法2: 分布边缘损失的应用:

  • 问题:随着新类别(严重和增殖性病变)的引入,存在新旧类别在特征空间中重叠的风险,这可能导致模型混淆不同阶段的病变。
  • 应用:通过分布边缘损失,调整模型的特征空间,使得各个病变阶段在特征层面上更加分离和明确。
  • 效果:模型可以更清晰地区分中度、严重和增殖性病变,提高了诊断的准确性和可靠性。
类增量学习算法流程

解决医学图像分类中,类别增量学习的不平衡问题_CIL_03


用类增量学习方法来更新医学图像分类模型:

  • 步骤 1:整合当前任务数据 解决医学图像分类中,类别增量学习的不平衡问题_数据_04 和前一任务的记忆样本 解决医学图像分类中,类别增量学习的不平衡问题_类增量学习_05,为训练过程做准备。
  • 步骤 2-9:重复执行以下步骤直到达到预定义的训练周期:
  • 步骤 4:应用CIL-平衡分类损失(调整模型以减轻新旧类别数据不平衡的影响)。
  • 步骤 5:应用分布边缘损失(改进类别在特征空间的分布,减少类间重叠)。
  • 步骤 6:应用知识蒸馏损失(保持新模型与旧模型之间的一致性,减少遗忘)。
  • 步骤 8:利用所有损失的组合优化模型。

举个例子,眼底疾病:

  • 步骤 1:整合现有的无病变、轻度和中度病变图像数据以及前一阶段记忆中的样本,为引入严重和增殖性病变做准备。
  • 步骤 4:应用CIL-平衡分类损失,调整无病变、轻度和中度病变的数据权重,以弥补因严重和增殖性病变图像较少而可能导致的训练不平衡。
  • 步骤 5:应用分布边缘损失,确保严重和增殖性病变的特征在模型中清晰地与其他类别分开,减少误诊。
  • 步骤 6:应用知识蒸馏损失,帮助新模型记住在没有严重和增殖性病变类别时已经学习到的无病变、轻度和中度病变的知识,防止在学习新类别时遗忘这些旧知识。
  • 步骤 8:通过优化包含所有这些损失的总损失函数 解决医学图像分类中,类别增量学习的不平衡问题_CIL_06,更新模型,使其能够更准确地诊断所有类型的糖尿病视网膜病变。