Transformer 在医学图像分类中的应用

这篇文章的重点是Vision Transformer(ViT)及其在实际问题中的应用。Transformer架构已经成为自然语言处理任务的事实标准。什么是Vision Transformer(ViT)?ViT架构基于图像表示,将图像表示为一组补丁。图像补丁是大小为16x16像素的非重叠图像块。例如,在分辨率为224x224的图像中,有(224 / 16) (224 / 16) = 14  14 = 196个补丁。图像补丁与NLP应用中的令牌(单词)一样对待。ViT将每个补丁表示为其像素的平坦线性投影,并使用长度为768的补丁嵌入向量进行操作(16x16x3 = 768)。下图显示了ViT的完整架构:

cf9590be342138db13009c52a2054a13.jpeg

ViT架构Transformer的主要部分包括:补丁 + 位置嵌入准备、编码器、池化(多层池化头)。

1. 补丁 + 位置嵌入是从输入图像像素中形成的矩阵,大小为196 x 768(每个补丁位置有768个值的向量,在图像大小为224 x 224时有196个补丁)。在零位置,添加了一个随机初始化的具有768个值的向量,因此补丁 + 位置嵌入是大小为197 x 768的矩阵。

2. 编码器包含一系列多头注意块,后跟标准化层和多层池化块。Transformer编码器是ViT的主要部分,它根据它们的类别从训练相似性的补丁序列。它包含一系列线性、标准化和激活层。大小为197 x 768的嵌入矩阵被转换以表示补丁之间的交互,并表示它们的类值。此矩阵的零位置行是类令牌(768个值的向量),它被用作以下池化块的输入。

3. 池化块最终将类令牌(768个值的向量)转换为感兴趣的类别的嵌入向量的输出。此块中还使用了线性和激活层。

Hugging Face中的ViT实现理解实践

让我们看看Hugging Face的基本ViT模型,使用以下代码块:

安装:

!pip install torchvision
!pip install torchinfo
!pip install -q git+https://github.com/huggingface/transformers.git

导入:

from PIL import Image
from torchinfo import summary
import torch

Google云硬盘挂载(对于Google Colab):

from google.colab import drive
drive.mount('/content/gdrive')

Cuda设备设置:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

在下面的代码中,我查看了ViT基本模型:

from transformers import ViTConfig, ViTModel
configuration = ViTConfig()
print(configuration)

默认基本模型配置如下:

ViTConfig {
"attention_probs_dropout_prob": 0.0,
"encoder_stride": 16,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"image_size": 224,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"model_type": "vit",
"num_attention_heads": 12,
"num_channels": 3,
"num_hidden_layers": 12,
"patch_size": 16,
"qkv_bias": true,
"transformers_version": "4.37.0.dev0"
}

通过更改配置的字段,我们可以创建一个自定义的ViT模型。让我们尝试默认的ViT基本模型:

model = ViTModel(configuration).to(device)
model.eval()

在输出中,我们可以看到所有ViT基本模型的层次结构。

b0e4eb0d8ed0ed1f94559f1f6162db1b.jpeg

模型summary:

summary(model=model, input_size=(1, 3, 224, 224), col_names=['input_size', 'output_size', 'num_params', 'trainable')

12467f848f12a2b1163beb150e862670.jpeg

ViT基本模型有大量的参数,超过8600万。让我们看看模型输出结构。我向模型发送了一个随机生成的虚假图像:

x = torch.randn((3, 224, 224))
x = torch.unsqueeze(x, 0)
y = model(x.to(device))
print(y.pooler_output.shape)
print(y.last_hidden_state.shape)

在输出中,我们可以看到:

torch.Size([1, 768)
torch.Size([1, 197, 768)

ViT基本模型的最终输出包含两部分:last_hidden_state,形状为(batch_size,197,768),这是模型.pooler部分之前模型.embeddings + model.encoder + model.layernorm(见图1)的输出;pooler_output,形状为(batch_size,768),这是模型.pooler的输出。在模型.pooler块的输入中,有一个规范化的last_hidden_state矩阵的零位置行,该矩阵在先前步骤中获得。下图说明了逐步调用块(上述)的等价性以及通过一次调用整个模型获取模型输出:

3f3a1f569fb5baa3caba9f33e9407822.jpeg

如果我们在图2中同时运行左侧代码和右侧代码,并使用相同的输入张量x进行打印,我们将看到相同的输出张量。理解ViT块及其输出结构对使用ViT进行迁移学习的解决方案的开发非常重要。Model.pooler块更改为自定义块,并使用先前块的ViT模型的推理进行训练。

Hugging Face中提供了两个用于图像分类的预训练ViT模型:

1. 在ImageNet-21k上预训练(包含1400万张图像和21k个类别的集合);

2. 在ImageNet上微调(也称为ILSVRC 2012,包含130万张图像和1000个类别的集合)。

微调在ImageNet上的分类器架构(ViTForImageClassification)包含model.pooler块而不是model.pooler块,仅包含以下线性层:

(classifier): Linear(in_features=768, out_features=1000, bias=True)

这一层的输入是规范化的last_hidden_state矩阵的零位置行。

ViT与CNN的对比

1. CNN模型从图像中获取所有局部特征,并将整体特征集合视为整个图像进行分类。它被训练为基于所有特征计算图像的类标签。ViT将图像视为一组补丁,并考虑补丁的位置。它被训练为计算补丁嵌入之间的相似性,并决定“相似”补丁的类标签,即ViT架构包含分割的概念。

2. ViT模型具有大量参数(在上面的summary中为8600万),并且需要大型数据集以获得良好的性能。CNN模型可以适应不同大小的数据集,并且可能需要相对较少的参数以获得良好的性能。

ViT如果从头开始在小型自定义数据集上表现不佳。小型自定义数据集的使用案例是使用在大型数据集上预训练的ViT模型进行迁移学习。

ViT用于X射线胸部图像分类 —— 实际实验

在这一部分,我回到了我用CNN解决的任务,并在这里进行了描述。我使用相同的X射线胸部图像数据集。该数据集包含三类图像:

ddf045d7d00a73335f52a60fdff70ca6.jpeg

我使用了统一的裁剪图像,其中包含胸部区域。裁剪图像的示例(从左到右为“正常(无肺炎)”、“肺炎-细菌”、“肺炎-病毒”):

f1c9c12042ea1660720c00f2a8ec36f3.jpeg

该数据集被分为训练集和测试集。训练集包含3000张图像,其中包括1000张“正常(无肺炎)”、1000张“肺炎-细菌”和1000张“肺炎-病毒”图像,从各自的组中随机选择。其余的图像构成测试集,因此包含2908张图像,其中包括576张“正常(无肺炎)”、1777张“肺炎-细菌”和555张“肺炎-病毒”图像。

CNN vs ViT用于2类分类器“正常(无肺炎)”/“肺炎(细菌或病毒)”

我使用X射线数据解决了以下任务:创建一个系统,该系统可以确定输入的X射线胸部图像属于“正常(无肺炎)”类还是“肺炎(细菌或病毒)”类,即2类分类器,使用ViT。我已经使用包含3个卷积块的CNN实现了解决方案,该模型在这里进行了描述。这个模型在这个数据集中是CNN模型中效果最好的。3卷积模型的summary如下:

7f7db5c697f81be233919b8958784bbc.jpeg

该模型包含348,050个参数,比ViT模型的参数要少得多。请注意,对于CNN模型,我使用分辨率为256x256的图像。在这里,我尝试使用在ImageNet-21k数据集上预训练的ViT模型,并对X射线图像进行微调。

模型1:在经过ViT处理的输入图像之后添加“小型”线性分类器

首先,我尝试最简单的解决方案,即一个线性层,其输入是从last_hidden_state矩阵中的零位置行中获取的向量,该向量具有768个值。这种最终拟合方式适用于ImageNet数据集上具有1000个类别的图像分类器。

加载预训练的ViT模型+图像处理器:

from transformers import ViTConfig, ViTModel
from transformers import AutoImageProcessor


image_processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
model = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")

以下代码显示了将最初以PIL Image形式存在的一张图像预处理为来自预训练ViT模型的类令牌向量(768个值)的过程:

img = <load PIL Image>


inputs = image_processor(img, return_tensors="pt")


with torch.no_grad():
    outputs = model(**inputs)


img = outputs.last_hidden_state
img = img[:, 0, :]

注意:我假设二维图像将被发送到torch DataLoader,该DataLoader会向图像批次添加批次维度。

一批经过处理的输入图像,形状为(batch_size, 1, 768),被发送到以下模型:

class ChestClassifier(nn.Module):
def __init__(self, num_classes):
super(ChestClassifier, self).__init__()
self.num_classes = num_classes
self.ln = nn.Linear(768, self.num_classes)
def forward(self, x):
x = nn.Flatten()(x)
x = self.ln(x)
return x


model1 = ChestClassifier(2).to(device)

该“小型”分类器模型的summary:

summary(model=model1, input_size=(1, 1, 768), col_names=['input_size', 'output_size', 'num_params', 'trainable')

a4b6d8be2018a4338b1e21c5138c0bce.jpeg

“小型”分类器模型仅包含1,538个参数。

我使用Adam优化器和学习率为0.001。训练使用了3000张图像,测试使用了2908张图像。我使训练批次保持平衡(每类约50%的图像)。在下图中,比较了CNN架构(其结果已在此处获得并呈现)和上述model1的结果。在下面的结果中,“Class 0”表示“正常(无肺炎)”,“Class 1”表示“肺炎(细菌或病毒)”。对于两个模型,我选择了最佳检查点:

e8b22f008cc8c191f8a4b89021fcc785.jpeg

ViT微调与“小型”线性分类器的结果明显不如CNN架构的结果。我认为这些结果的原因有以下几点:医学图像与ViT模型训练时使用的ImageNet数据非常不同,并且我的“小型”线性分类器的可训练参数数量不足以使迁移学习结果优于CNN模型结果。

如何改进模型呢?首先,我可以使用整个预训练的补丁位置状态——ViT输出的整个last_hidden_state——来微调分类器。其次,我可以尝试使用可训练参数更多的更复杂的分类器模型。

模型2:在经过ViT处理的输入图像之后添加“大型”线性分类器

与model1相比,我改变了输入PIL图像的预处理方式,以获取预训练ViT模型的整个转置last_hidden_state矩阵。该矩阵形成我的分类器模型的输入:

img = <load PIL Image>
inputs = image_processor(img, return_tensors="pt")
with torch.no_grad():
outputs = model(inputs)
img = outputs.last_hidden_state.permute(0, 2, 1)
img = img.squeeze()

注意:我使用`img.squeeze()`来去除单个图像的批次维度,因为我假设它将被发送到torch DataLoader,该DataLoader会向图像批次添加批次维度。一批经过处理的输入图像,形状为(batch_size, 768, 197),被发送到以下模型:

class ChestClassifierL(nn.Module):
def __init__(self, num_classes):
super(ChestClassifierL, self).__init__()
self.num_classes = num_classes
self.ln1 = nn.Linear(197, 256)
self.relu = nn.ReLU(inplace=True)
self.ln2 = nn.Linear(768256, self.num_classes)
def forward(self, x):
x = self.ln1(x)
x = self.relu(x)
x = nn.Flatten()(x)
x = self.ln2(x)
return x
model2 = ChestClassifierL(2).to(device)

该“大型”分类器模型的summary:

summary(model=model2, input_size=(1, 768, 197), col_names=['input_size', 'output_size', 'num_params', 'trainable')

16bc73e355b6c08368c672510872cf6c.jpeg

“大型”分类器模型包含443,906个参数。

我使用Adam优化器和学习率为0.001。下图显示了CNN架构(其结果已在此处获得并呈现)与上述model2的结果的比较。在下面的结果中,“Class 0”表示“正常(无肺炎)”,“Class 1”表示“肺炎(细菌或病毒)”。对于两个模型,我选择了最佳检查点:

8b70f8151fd8d93d33f8861c7f732bf6.jpeg

使用“大型”分类器的ViT微调显示出比CNN更好的性能!这个结果的原因不仅仅是可训练参数数量的增加,还考虑了整个补丁位置信息。分割概念对医学图像非常重要,因为它们可能包含特定问题的异常区域。

在下面,我展示了ViT在另一个分类器上的正面趋势——用于不同类型肺炎的分类器:“肺炎-细菌”和“肺炎-病毒”。我在训练集中有1000张“肺炎-细菌”图像加上1000张“肺炎-病毒”图像,并在测试集中使用1777张“肺炎-细菌”图像加上555张“肺炎-病毒”图像。因此,训练集包含2000张图像,测试集包含2332张图像。我比较了具有3个卷积块的相同CNN架构和相同的ViT加model2组合,如上述分类器“正常(无肺炎)”/“肺炎(细菌或病毒)”。在下面的结果中,“Class 0”表示“肺炎细菌”,“Class 1”表示“肺炎病毒”。对于两个模型,我选择了最佳检查点:

470fab7f3e739edaef3d8904ab069020.jpeg

我在之前的帖子中已经展示过,区分不同类型的肺炎质量较好是困难的。无论如何,上述结果显示,与CNN相比,ViT微调解决方案的性能更好。

Model3:用于自定义输入分辨率的ViT微调

在上面讨论的所有示例中,我比较了在分辨率为256x256的输入图像上训练的CNN模型与ViT微调结果,其中ViT预训练模型需要分辨率为224x224的输入图像。在本文中,我找到了在更高分辨率上进行迁移学习的解决方案:预训练模型的输出大小应根据更高分辨率的嵌入位置进行更改,然后将其发送到使用新分辨率进行微调的模型。一张224x224的图像有196个补丁,而ViT的last_hidden_state分辨率为197x768。一张256x256的图像有256个补丁,ViT的last_hidden_state分辨率应为257x768。因此,为了微调输入分辨率为256x256的ViT,我需要将last_hidden_state矩阵调整为分辨率257x768,并使用这个矩阵继续训练。

让我们在实践中尝试一下。输入PIL图像的预处理如下:

img = <load PIL Image>
inputs = image_processor(img, return_tensors="pt")
with torch.no_grad():
outputs = model(inputs)
img = outputs.last_hidden_state.permute(0, 2, 1)
# 新的补丁位置嵌入分辨率
img = transforms.Resize((768, 257))(img)
img = img.squeeze()

注意:我使用`img.squeeze()`来去除单个图像的批次维度,因为我假设它将被发送到torch DataLoader,该DataLoader会向图像批次添加批次维度。

一批经过处理的输入图像,形状为(batch_size, 768, 257),被发送到以下模型:

class ChestClassifierL256(nn.Module):
def __init__(self, num_classes):
super(ChestClassifierL256, self).__init__()
self.num_classes = num_classes
self.ln1 = nn.Linear(257, 256)
self.relu = nn.ReLU(inplace=True)
self.ln2 = nn.Linear(768256, self.num_classes)
def forward(self, x):
x = self.ln1(x)
x = self.relu(x)
x = nn.Flatten()(x)
x = self.ln2(x)
return x
model3 = ChestClassifierL256(2).to(device)

该模型的summary:

summary(model=model3, input_size=(1, 768, 257), col_names=['input_size', 'output_size', 'num_params', 'trainable')

9a1643b48ebe634c52fcd800da622309.jpeg

我已尝试将model3用于2类分类器“正常(无肺炎)”/“肺炎(细菌或病毒)”。下图显示了输入分辨率为224x224的model2与输入分辨率为256x256的model3的结果比较。在结果中,“Class 0”表示“正常(无肺炎)”,“Class 1”表示“肺炎(细菌或病毒)”。对于两个模型,我选择了最佳检查点:

abfb40a45363d0947bc2ae246c0fdb68.jpeg

分辨率变化的结果在与224x224非常不同的分辨率上更为明显。

结论

ViT推断和微调模型的适当组合可能会提高分类器的性能,即使在如医学图像这样非常特定的数据集上也是如此。

·  END  ·

HAPPY LIFE

547e3e67eef7d227fd15bc534d25f40f.png

本文仅供学习交流使用,如有侵权请联系作者删除

  • 20
    点赞
  • 27
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值