5.14.6 TransMed:Transformer推进多模态医学图像分类

卷积神经网络(CNN)在医学图像分析任务中表现出了非常有竞争力的性能,例如疾病分类、肿瘤分割和病灶检测。 CNN在提取图像局部特征方面具有很大的优势。然而,由于卷积运算的局部性,它不能很好地处理长程关系。

多模态医学图像具有明确且重要的长程依赖性,有效的多模态融合策略可以极大地提高深度模型的性能。

多模态医学图像是指通过不同的采集方法或成像技术,获取的不同类型的医学图像。

现有的基于 Transformer 的网络架构需要大规模数据集才能实现更好的性能。然而,医学成像数据集相对较小,这使得将纯 Transformer 应用于医学图像分析变得困难。


1. 介绍

TransMed 结合了 CNN 和 Transformer 的优点,可以有效地提取图像的低级特征并建立模态之间的远程依赖关系。我们在两个数据集(腮腺肿瘤分类和膝盖损伤分类)上评估了我们的模型。

将 Transformer 应用于计算机视觉任务的方法。与文本相比,图像涉及更大的尺寸、噪声和冗余模态。人们提出了大量基于 Transformer 的方法,例如用于目标检测的 DETR [2]、用于语义分割的 SETR [3]、用于图像分类的 ViT [4] 和 DeiT [5]。

现有的基于深度学习的医学图像多模态融合可以分为三类:

        输入级融合、特征级融合和决策级融合

输入级融合策略将多模态图像通过多通道融合到深度网络中,学习融合特征表示,然后训练网络。输入级融合可以最大程度地保留原始图像信息并学习图像特征。但是难以建立同一患者不同模态之间的内部关系,从而导致模型性能下降。

特征级融合策略通过将每种模态的图像作为单个输入来训练单个深度网络。每个表示在网络层进行融合,最终结果被馈送到决策层以获得最终结果。特征级融合网络可以有效捕获同一患者不同模态的信息。但是每种模态都对应一个神经网络,这带来了巨大的计算成本,特别是在模态数量较多的情况下。

决策级融合将各个网络的输出进行整合,得到最终结果。决策级融合网络旨在独立地从不同模态学习更丰富的信息。但是每种模态的输出是相互独立的,因此该模型无法建立同一患者不同模态之间的内在关系。决策级融合策略也是计算密集型的。

与 CNN 相比,Transformers 可以有效地挖掘序列之间的长程关系。现有的基于Transformer的计算机视觉模型主要处理2D自然图像,如ImageNet[7]等大规模数据集。在二维图像中构造序列的方法是将图像切割成一系列的块。这种序列构造方法隐式地展现了远程依赖关系,不是很直观,因此可能很难带来显着的性能提升。

而医学图像中有更明确的序列,其中包含重要的长程依赖性和语义信息。由于人体器官的相似性,大多数视觉表示在医学图像中都是有序的。破坏这些序列将显着降低模型的性能。与自然图像相比,医学图像的序列关系(如模态、切片、块)保存了更丰富的信息。

1.1 TransMed

它结合了 CNN 和 Transformer 的优点来捕获低级特征和跨模态高级信息。TransMed 首先将多模态图像处理为序列并将其发送到 CNN,然后使用 Transformer 来学习序列之间的关系并进行预测。由于Transformer有效地对多模态图像的全局特征进行建模,TransMed在参数、运算速度和准确性方面优于现有的多模态融合方法。多模态医学图像具有更多信息序列。

  1. 首先将Transformer应用于多模态医学图像分类,并以较低的计算成本显着提高了深度模型的性能。
  2. 提出了一种新颖的多模态图像融合策略,可以利用它以更有效的方式从不同模态的图像中捕获信息。

2. 相关工作

2.2 Transformers

Transformers 使用自注意力机制作为核心模块,构建无卷积深度网络。

与 CNN 相比,Transformers 不需要人类定义的归纳偏差,并且可以很好地处理长程依赖性。

MADGAN [26] 将自注意力模块集成到生成对抗网络中,用于无监督的医学异常检测。 Liu等人[27]开发了一种具有新颖特征金字塔注意机制的CNN,用于前列腺的自动分割。

TransUNet [33] 是第一个基于 Transformer 的医学图像分割框架,它使用 Transformer 对全局上下文进行编码。 CoTr [34] 提出了一种新颖的框架,可以有效地连接 CNN 和 Transformer 以进行 3D 医学图像分割。 UNETR [35]利用纯 Transformer 作为编码器来有效捕获多尺度信息.

TransMed是第一个基于Transformers的多模态医学图像分类框架,它提供了一种新颖的多模态图像融合策略。

3 方法

多模态医学图像分类最常见的方法是直接训练CNN(例如Resnet[36])。首先,图像被编码为高级特征表示,然后融合其特征或决策。与现有方法不同,我们的方法使用 Transformer 将自注意力机制引入多模态融合策略中。

3.1 Transformers 聚合多模态特征

Transformer的重要组成部分包括自注意力(SA)、多头自注意力(MSA)和多层感知(MLP)。 Transformer 的输入包括各种嵌入和令牌。与 DeiT 略有不同,我们删除了线性投影层和蒸馏令牌。

3.1.1 自注意力(SA)

在 SA 层中,输入向量 X 首先被转换为三个不同的向量:查询矩阵 Q、键矩阵 K 和值矩阵 V:

Q=XW_q,K=XW_k,V=XW_v 

其中 W_qW_kW_v 为可训练矩阵。然后,分配给每个值的权重由 Q 和对应 K 的点积决定。不同输入向量之间的注意力函数计算如下:

Attention(Q,K,V)=Softmax(\frac{QK^\mathrm{T}}{\sqrt{d_k}})\cdot V 

其中 d_k 是向量 k 的维度。\sqrt{d_k} 提供适当的归一化,使梯度更加稳定。

3.1.2 多头自注意力

MSA是 Transformer 的核心部件。如图所示,与SA的区别在于多头机制将输入分割成许多小部分,然后并行计算每个输入的缩放点积,并拼接所有注意力输出以获得最终结果。 MSA的公式可以写成:

head_{i}=Attention(QW_{i}^{Q},KW_{i}^{K},VW_{i}^{V})

 MSA(Q,K,V)=Concat(head_1,\ldots,head_i)W^O

其中投影 W^Q_i ,W^K_i ,W^V_iW^O 是可训练参数矩阵; h 是变压器层数。 MSA的优点在于它允许模型学习不同表示子空间中的序列和位置信息。 

3.1.3 多层感知器

MLP 添加在 MSA 层之上。 MLP 由 GeLU 激活分隔的线性层组成。 MSA 和 MLP 都具有类似于残差网络的跳跃连接和层归一化。

假设第 t−1 层的表示为x_{t-1},LN表示线性归一化,第 t 层的输出可以写成如下:

 \hat{x}_{t}=MSA(LN(x_{t-1}))+x_{t-1}

x_{t}=MLP(LN(\hat{x}_{t}))+\hat{x}_{t} 

3.1.4 嵌入和标记(tokens)

输入层包含 5 个 embedding 和 token,分别是 patch embedding、position embedding、class embedding、patch token 和 class token。

patch embedding 是 CNN 中每个 patch 输出的表示,class embedding 是一个可训练的向量。为了将 patch 的空间信息和位置信息编码成 patch token,我们使用 position embedding 和 patch embedding 来保存这些信息。class embedding 没有可以添加的 patch embedding,所以 class token 和 class embedding 是等价的。

假设输入为 x,可训练向量为 W^c,position embedding 为 x_{po},patch token x_{pt} 和 class token x_{ct} 可以表示如下:

x_{pt}=Conv(x)+x_{po}

x_{ct}=W^{c} 

类标记在 Transformer 输入层之前附加到补丁标记,通过 Transformer 层,然后从全连接层输出以预测类别。

3.2 TransMed

TransMed 没有使用纯 Transformer 作为编码器,而是采用了包括 CNN 和 Transformer 的混合模型,其中 CNN 用作低级特征提取器来生成补丁嵌入。

给定一个多模态图像 x\in R^{N\times C\times D\times H\times W},其中空间分辨率为 H × W,深度为 D,通道数为 C,模态数为 N。在将其发送到之前对于CNN编码器来说,需要构造序列。

首先,结合多模态图像的通道维度、深度维度和模态维度,得到x^{\prime}\in R^{(N\times C\times D)\times H\times W}

然后,将多模态图像的三个相邻2D切片叠加以构造三通道图像 x^{\prime\prime}\in R^{(1/3\times N\times C\times D)\times3\times H\times W}。每个图像将被分为K\times K。K值越大意味着每个patch的大小越小。最后,将图像编码为一个patch   x_{input}\in R^{(1/3\times N\times C\times D\times K^{2})\times3\times(H/K)\times(W/K)}。图像序列构建完成后,输入到2D CNN。 2D CNN 的最后一个全连接层被线性投影层取代,以将向量块的特征映射到潜在的嵌入空间。 2D CNN 从图像序列中提取低级特征并对其进行初步编码。

从多模态图像中选择三个连续的二维切片(或层面),并将它们作为三个通道叠加在一起。这样做的目的是将多个模态或层面的信息融合到一个三通道的图像中,类似于RGB图像中的红、绿、蓝三个通道。

每个通道可以代表不同的模态或层面的信息,通过叠加这些通道,我们可以得到一个包含多个模态信息的复合图像。

多模态融合新方向icon-default.png?t=N7T8https://blog.csdn.net/2401_82426425/article/details/135929746?ops_request_misc=%257B%2522request%255Fid%2522%253A%2522171627178916800182723802%2522%252C%2522scm%2522%253A%252220140713.130102334..%2522%257D&request_id=171627178916800182723802&biz_id=0&utm_medium=distribute.pc_search_result.none-task-blog-2~blog~top_click~default-2-135929746-null-null.nonecase&utm_term=%E5%A4%9A%E6%A8%A1%E6%80%81%E8%9E%8D%E5%90%88&spm=1018.2226.3001.4450

4. 结果

腮腺肿瘤(PGT)数据集和 MRNet 数据集

4.1 数据集

4.1.1 腮腺肿瘤(PGT)数据集

PGT在成像特征(如肿瘤边缘、均匀性和信号强度)上表现出相当大的重叠。

使用分层随机抽样来确保每个标签至少有 5 个和 10 个正例分别出现在验证集和测试集中。训练集用于优化模型参数,验证集用于选择最佳模型。

在数据预处理阶段,我们首先执行 OTSU[45] 来提取原始图像中的前景区域。然后对同一患者的不同模态的图像进行配准,以提高前景区域的一致性。然后将每张图像重新采样为18×448×448。每个图像都是一个由36个3D MRI图像(包括T1和T2两种模态)堆叠而成的集合,每个3D图像的空间分辨率为36 ×448 × 448像素。数据增强使用随机翻转和随机噪声。随机翻转以 50% 的概率执行图像翻转。随机噪声向图像添加平均值为0、方差为0.1的高斯噪声。

4.1.2 MRNet 数据集

将数据集随机分为 1130 个训练案例、120 个验证案例和 120 个测试案例。提供的数据集包括三种 MRI 模式(T1 加权图像、T2 加权图像和质子密度加权图像)。每个图像的大小为 256 × 256,切片数量在 17 到 61 之间。

5. 讨论

首次将 Transformer 应用于多模态医学图像分类,因为它可以有效地探索 CNN 难以捕获的序列信息。在这项工作中,我们首先使用 ResNet 提取图像特征,然后使用 Transformer 捕获序列之间的长程依赖关系。基于 CNN 和 Transformer 的混合架构可以显著提高多模态医学图像分类性能。

自注意力机制不具有类似于CNN结构的归纳偏差。尽管当数据量足够大时,Transformer 结构被证明超越了归纳偏置带来的领域知识。然而,医学图像数据集较小,无法达到令人满意的性能。

  • 53
    点赞
  • 23
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值