2023港科大新作 | 新颖注意力机制有效提升医学图像小样本语义分割精度!

欢迎关注『CVHub』官方微信公众号!

Title: Few Shot Medical Image Segmentation with Cross Attention Transformer

PDF: https://arxiv.org/pdf/2303.13867

Code: coming soon…

导读

在深度学习医学图像分割领域,训练一个性能强,可以大规模部署落地的模型,往往需要大量手动标注的数据进行监督训练,其中花费的成本是非常高的。为了解决这一挑战,少样本学习(few-shot)技术有潜力从有限的几个sample中学习新类别的能力。本文提出了一种基于交叉掩码注意力Transformer的少样本医学图像分割新框架CAT-Net:通过挖掘supportquery图像之间的相关性,并限制模型仅关注有用的前景信息,来提高supportquery特征的表达能力。同时,本文还进一步设计了一个迭代细化训练框架来优化查询query图像分割。作者在三个公共数据集(Abd-CT,Abd-MRI和Card-MRI)上验证了所提出方法的有效性。

引言

我们都知道,图像分割在医学影像中是一项非常常见的任务。在工业界,绝大部分医学图像分割任务还是基于全监督进行训练(或许只因为它稳定且精度高,只要公司不断的堆高质量标注数据就可以了,这个成本公司还是愿意花的);但医学图像的标注是一项耗时长,成本高的任务,比如对于一个3D的volume图像(以CT和MRI居多),图像标注则更具有挑战性:第一,标注者需要浏览每个3D扫描的数百个2D切片进行标注,量非常大;第二,有些标注外行人还真的做不好,这往往需要一些专业医生动用专业知识进行标注;第三,这些专业医生往往还有其它很多事情要做,你让医生去标注大量数据还真得考虑可行性。。。

基于这个出发点(痛点),科研界许多研究学者为了解决手动标注所带来的挑战,开辟了各种研究方向(至于公司愿不愿意去运用此类方向做产品就看情况了):例如自监督学习,半监督学习和弱监督学习。尽管利用未标记或弱标记数据的信息,这些技术仍然需要大量的训练数据,这对于医学领域中仅有有限样本的新类别可能不可行。于是,,,研究学者开辟了few-shot这么一个研究方向:few-shot 学习范式旨在从少量标记数据(称为support)中学习模型,然后将其应用于仅有少量标记数据的新任务(称为query),而无需重新训练。

考虑到人体内的数百个器官和无数的疾病,few-shot学习为各种医学图像分割任务带来了巨大的潜力,可以在数据高效的情况下轻松地研究新任务。

大多数few-shot分割方法都在学习如何学习(旨在学习元学习器),根据support图像及其相应的分割标签的知识预测query图像的分割,而这里的核心是:如何有效地将知识从support图像传递到query图像。现有的少样本分割方法主要集中在以下两个方面:

  1. 如何学习一个元学习器
  2. 如何更好地将知识从support图像传递到query图像

尽管基于原型的方法效果已经不错,但它们通常忽略了训练过程中supportquery特征之间的交互。

因此,本文提出了一种名为CAT-Net的新型网络结构,其基于交叉注意力Transformer,可以更好地捕捉support图像和query图像之间的相关性,促进supportquery特征之间的相互作用,同时减少无用像素信息,提高特征表达能力和分割性能;此外,本文还提出了一个迭代训练框架,将先前的support分割结果反馈到注意力Transformer中,以有效增强并细化特征和分割结果。作者在三个公共数据集上验证了CAT-Net的有效性和性能优越性。

few-shot定义

少样本分割(Few-shot segmentation,FSS)的目的是通过只有少量标注的样本来分割新类别。在FSS中,数据集被分为训练集Dtrain和测试集Dtest,其中训练集包含基类别Ctrain,测试集包含新类别Ctest,且CtrainCtest没有交集。为了获得用于FSS的分割模型,采用了通常使用的episode训练方法。每个训练 / 测试 $ \mathrm{episode(S_i,Q_i)} 实例化一个 N − w a y , K − s h o t 分割学习任务。具体而言 : s u p p o r t 集 实例化一个N-way, K-shot分割学习任务。具体而言: support集 实例化一个Nway,Kshot分割学习任务。具体而言:support\mathrm{S_i} 包含 N 个类别的 K 个样本,而 q u e r y 集 包含N个类别的K个样本,而query集 包含N个类别的K个样本,而query\mathrm{Q_i}$包含同一类别的一个样本。FSS模型通过episode训练以预测query图像的新类别。在模型推理测试时,模型直接在Dtest上进行评估,无需重新训练。

方法

图1. Overview of the CAT-NET

如上图1展示了CAT-Net网络框架图,主要由三部分组成:

  1. 带有mask的特征提取MIFE子网络,用于提取初始querysupport特征以及query mask
  2. 交叉mask注意力Transformer模块CMAT,其中querysupport特征相互促进,从而提高query预测的准确性
  3. 迭代细化框架,顺序应用CMAT模块以持续促进分割性能,整个框架以端到端的方式进行训练

Mask Incorporated Feature Extraction

CAT-Net中的Mask Incorporate Feature Extraction (MIFE)子网络。MIFE子网络接收查询和支持图像作为输入,生成它们各自的特征,同时集成支持掩膜。然后,使用一个简单的分类器来预测查询图像的分割结果。具体地,首先使用一个特征提取器网络(即ResNet-50)将查询和支持图像对Iq和Is映射到特征空间中,分别产生查询图像的多层特征图Fq和支持图像的特征图Fs。接下来,将支持掩膜与Fs进行池化,然后将其扩展并与Fq和Fs进行连接。此外,还将一个先验掩膜进一步与查询特征进行连接,通过像素级相似度图来增强查询和支持特征之间的相关性。最后,使用一个简单的分类器来处理查询特征,得到查询掩膜。关于MIFE架构的更多细节可以在补充材料中找到。

Cross Masked Attention Transformer

CMAT模块包括三个主要组成部分:自注意力模块、交叉掩码注意力模块,和原型分割模块。其中,自注意力模块用于提取查询query特征和支持support特征中的全局信息;交叉掩码注意力模块用于在传递前景信息的同时消除冗余的背景信息;原型分割模块用于生成查询图像的最终预测结果。

自注意力模块

自注意力模块首先将查询特征 F 0 q F_0^q F0q 和支持特征 F 0 s F_0^s F0s 展平为1D序列,然后输入到两个相同的自注意力模块中。每个自注意力模块由一个多头注意力层和一个多层感知器MLP层组成。给定一个输入序列 S S SMHA层首先使用不同的权重将序列投影为三个序列 Q Q Q K K K V V V。然后计算注意力矩阵 A A A,公式为:

其中, d d d 是输入序列的维度。注意力矩阵通过 softmax 函数归一化,并乘以值序列 V V V 以获得输出序列 O O O。MLP层是一个简单的 1 × 1 1 \times 1 1×1 卷积层,将输出序列 O O O 映射到与输入序列 S S S 相同的维度。最终,将输出序列 O O O 添加到输入序列 S S S 中,并使用层归一化(LN)对其进行规范化,以获得最终的输出序列 X X X。自注意力对齐编码器的输出特征序列分别表示为 X q ∈ R H W × D X^q \in \mathbb{R}^{HW \times D} XqRHW×D X s ∈ R H W × D X^s \in \mathbb{R}^{HW \times D} XsRHW×D,分别对应于查询和支持特征。

交叉掩码注意力模块

用于将查询特征和支持特征按照它们的前景信息结合起来。在attention矩阵中,通过支持和查询的掩码来限制注意力区域。具体来说,给定查询特征 X q X^q Xq和来自自注意力模块的支持特征 X s X^s Xs,首先使用不同的权重将输入序列投影到三个序列 K K K Q Q Q V V V中,从而得到 K q K^q Kq Q q Q^q Qq V q V^q Vq K v K^v Kv Q v Q^v Qv V v V^v Vv。以查询特征为例,交叉注意力矩阵通过下面的公式计算得到:

其中, d d d表示查询特征的维度。这里使用的是点积注意力的形式,通过 K q K^q Kq Q s Q^s Qs的点积计算查询和支持之间的相关性。通过 d \sqrt{d} d 来缩放点积,防止在较高维度时点积的大小对注意力分布的影响过大。

原型分割模块

首先,通过一个“masked average pooling”的方法,建立每个类别的原型(prototype) p c p_c pc,用于表示该类别的特征分布。

其中, K K K是支持集中图像的数量, m ( k , x , y , c ) s m^s_{(k,x,y,c)} m(k,x,y,c)s是一个二进制掩模,表示位置 ( x , y ) (x,y) (x,y)在支持特征 k k k中是否属于类别 c c c F 1 s F_1^s F1s是支持特征。具体来说,对于每个类别 c c c,该原型是在所有支持图像中该类别对应位置的特征平均值,这样可以得到每个类别的原型 p c p_c pc

接着使用非参数度量学习方法进行分割。原型网络计算查询特征向量与原型 P = P c ∣ c ∈ C P={P_c | c \in C} P=PccC之间的距离。对所有类别应用softmax函数,生成查询分割结果:

其中cos(·)表示余弦距离,α是一个缩放因子,有助于在训练中反向传播梯度,其中α设置为20。

Iterative Refinement framework

该模块的设计目的是优化查询和支持特征以及查询分割掩模。因此可通过迭代优化的思路进行精细化分割,第i次迭代后的结果由以下公式给出:

每个步骤的细分可表示如下:

其中CMA(·)表示自注意力和交叉掩码注意力模块,Proto(·)代表原型分割模块,该公式表示通过多次迭代应用CMA和Proto模块,来获得增强的特征和优化的分割结果。

实验结果

::: block-1

作者将他们的方法与目前在腹部CT、腹部MRI和心脏MRI数据集上表现最优的方法进行了比较,使用了Dice系数作为评估指标。该比较在两种不同的实验设置(I和II)下进行。
:::

::: block-1

在 Abd-CT 和 Abd-MRI数据集上,相比于之前的最先进方法(SOTAs),这个提出的方法能够生成更加准确和详细的分割结果。
:::

::: block-1

验证了网络中各个组件的有效性:S→Q和Q→S表示CAT-Net中用于增强支持或查询特征的一条支路,而S↔Q表示将交叉注意力应用于S和Q。
:::

::: block-1

在不同迭代次数下使用CMAT模块的影响,可以观察到:增加模块数量可以提高性能,在使用5个模块时,Dice系数最大提高了2.26%。考虑到使用4和5个CMAT模块之间的性能提升不显著,因此作者选择在最终模型中使用四个CMAT模块,以在效率和性能之间取得平衡。
:::

结论

本文提出了一种用于few-shot医学图像分割的交叉注意力Transformer网络CAT-Net。通过交叉掩码注意力模块实现了查询和支持特征之间的交互,增强了特征表达能力。此外,所提出的CMAT模块可以通过迭代优化的方式以持续提高分割性能,实验结果表明了每个模块的有效性以及模型相对于SOTA方法的卓越性能。其中论文中的各个组件属于即插即用模块,可很好的嵌入到few shot任务中,以提高少样本分割的性能。


如果您也对人工智能和计算机视觉全栈领域感兴趣,强烈推荐您关注有料、有趣、有爱的公众号『CVHub』,每日为大家带来精品原创、多领域、有深度的前沿科技论文解读及工业成熟解决方案!欢迎添加小编微信号: cv_huber,备注"CSDN",加入 CVHub 官方学术&技术交流群,一起探讨更多有趣的话题!

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

CVHub

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值