Residual Attention U-Net 论文笔记

Residual Attention U-Net 论文笔记

原文地址:RAUNet: Residual Attention U-Net for Semantic Segmentation of Cataract Surgical Instruments

Abstract

手术器械的语义分割在机器人辅助手术中起着至关重要的作用。然而,由于镜面反射和等级不平衡的问题,白内障手术器械的精确分割仍然是一个挑战。本文提出了一种attention-guided网络来分割白内障手术器械。设计了一个新的注意模块来学习辨别特征,并解决镜面反射问题。它捕获全局上下文并对语义依赖进行编码,以强调关键语义特征,增强特征表示。这个注意力模块的参数很少,这有助于节省内存。因此,它可以灵活地插入其他网络。此外,还引入了一种hybrid loss(混合损耗)来训练我们的网络以解决类不平衡问题,它融合了cross entropy and logarithms of Dice loss。我们构建了一个名为Cata7的新数据集来评估我们的网络。据我们所知,这是第一个用于语义分割的白内障手术器械数据集。基于该数据集,RAUNet实现了最先进的性能,平均骰子率为97.71%,平均IOU率为95.62%。

Keywords: Attention, Semantic Segmentation, Cataract, Surgical Instrument

**关键词:**注意,语义分割,白内障,手术器械

1、Introduction

近年来,手术器械的语义分割在机器人辅助手术中得到了越来越广泛的应用。其中一个关键应用是手术器械的定位和姿态估计,这有助于手术机器人的控制。分割手术器械的潜在应用包括客观的手术技能评估、手术流程优化、报告生成等[1]这些应用可以减少医生的工作量,提高手术的安全性。

白内障手术是世界上最常见的眼科手术。每年大约执行1900万次[2]。白内障手术对医生的要求很高。计算机辅助手术可以显著降低意外手术的概率。然而,大多数与手术器械分割相关的研究都集中在内镜手术上。关于白内障手术的研究很少。据我们所知,这是第一项对白内障手术器械进行分割和分类的研究

最近,人们提出了一系列分割手术器械的方法。Luis等人[3]提出了一种基于完全卷积网络(FCN)和光流的网络,以解决手术器械的堵塞和变形等问题。RASNet[4]采用注意模块来强调目标区域,并改进特征表示。Iro等人[5]提出了一种新的U形网络,可以同时提供仪器的分割和姿态估计。Mohamed等人[6]采用了一种结合递归网络(RNN)和卷积网络(CNN)的方法来提高分割精度。综上所述,可以看出卷积神经网络在外科器械分割中取得了优异的性能。然而,上述方法都是基于内窥镜手术。白内障手术器械的语义分割与内窥镜手术有很大不同

白内障手术器械的语义分割需要面对许多挑战。与内窥镜手术不同,白内障手术需要强烈的光照条件,导致严重的镜面反射。镜面反射改变了手术器械的视觉特性。此外,白内障手术器械对于显微操作来说也很小。因此,手术器械通常只占据图像的一小部分。背景像素的数量远大于前景像素的数量,这导致了严重的类不平衡问题。因此,手术器械更容易被误认为是背景。眼组织和摄像头视野受限造成的遮挡也是一个重要问题,导致手术器械的一部分不可见。这些问题使得识别和分割手术器械变得困难。

为了解决这些问题,提出了一种新的网络——Residual Attention U-Net(RAUNet)。它引入了一个注意模块来改进特征表示。这项工作的贡献如下。

  1. 一个创新的注意力模块被称为增强注意力模块(AAM),旨在有效融合多层次特征并改进特征表示,有助于解决镜面反射问题。此外,它的参数很少,这有助于节省内存。

  2. 引入 hybrid loss来解决类间不平衡问题。它融合了cross entropy and logarithm of Dice loss,以充分利用两者的优点。

  3. 为了评估所提出的网络,我们构建了一个名为Cata7的白内障手术仪器数据集。据我们所知,这是第一个可以用于语义分割的白内障手术仪器数据集。

2、Residual Attention U-Net

2.1、Overview

高分辨率图像提供更详细的位置信息,帮助医生进行准确的手术。因此,Residual Attention U-Net (RAUNet)采用encoder-decode结构来获得高分辨率的masks。RAUNet的体系结构如图1所示。在ImageNet上预先训练的ResNet34[7]被用作编码器来提取语义特征。它有助于减少模型大小,提高推理速度。在解码器中,设计了一个新的注意模块——增强注意模块(AAM),用于融合多层次特征和捕获全局上下文。此外,使用转置卷积进行上采样以获得细化的边缘

image-20220416091212443

2.2 Augmented Attention Module

解码器通过上采样恢复位置细节。然而,上采样会导致边缘模糊和位置细节丢失。现有的一些工作[8]采用跳转连接将低级特征与高级特征连接起来,这有助于补充位置细节。但这是一种幼稚的方法。由于底层特征中缺乏语义信息,因此包含了大量无用的背景信息。该信息可能会干扰目标对象的分割。为了解决这个问题,增强注意模块被设计成捕捉高级语义信息并强调目标特征

每个通道对应一个特定的语义响应。外科器械和人体组织通常与不同的通道有关。因此,增强注意模块对语义依赖进行建模,以强调目标通道。它捕获高级特征映射中的语义信息,以及低级特征映射中的全局上下文,以编码语义依赖关系。High-level feature 包含丰富的语义信息,可用于指导low-level feature 选择重要的位置细节。此外, 底层特征映射的全局上下文编码了不同通道之间的语义关系,有助于过滤干扰信息。通过有效地利用这些信息,增强注意模块可以突出目标区域并改进特征表示。增强注意模块如图2所示。

image-20220416091925228

全局平均池用于提取全局上下文和语义信息,如等式(3)所述。它将全局信息压缩成一个关注向量,对语义依赖进行编码,有助于强调关键特征和过滤背景信息。注意向量的生成如下所述

image-20220416092159171

其中x和y分别指高级和低级特征图。g表示全局平均池。δ1表示ReLU函数,δ2表示Softmax函数。Wα,Wβ,Wñ是指1×1卷积的参数。bα,bβ,bñ指偏差。

image-20220416092325075

where k = 1, 2, …, c and x = [x1, x2, …, xc].

然后对向量进行1×1卷积和批量归一化,以进一步捕获语义依赖。采用softmax函数作为激活函数,对矢量进行归一化。底层特征映射与关注向量相乘,生成关注特征映射。最后,通过添加高级特征图来校准关注特征图。与乘法相比,加法可以减少卷积的参数,这有助于降低计算成本。此外,由于它只使用全局平均池和1×1卷积,因此该模块不会添加太多参数。全局平均池将全局信息压缩成一个向量,这也大大降低了计算成本。 (提升性能的同时,计算成本大大降低

2.3 Loss Function

手术器械的语义分割可以看作是对每个像素进行分类。因此,交叉熵损失可用于像素分类(分类常用损失函数)。它是分类中最常用的损失函数。在等式(4)中表示为H。 公式如下:

image-20220416093709521

其中w,h代表预测的宽度和高度。c是classes的数量。yijk是像素的ground truth,byijk是像素的prediction。

手术器械通常只占据图像的一小部分,这导致了严重的类别失衡问题。然而,交叉熵的性能受这个问题的影响很大。预测更倾向于将像素识别为背景。因此,手术器械可能会被部分检测到或忽略。公式(5)中定义的The Dice loss可用于解决此问题【9】。它评估预测和The Dice loss之间的相似性(在医疗分割中比较常用的损失函数),这不受前景像素与背景像素之比的影响。

image-20220416094220785

其中w,h代表预测的宽度和高度,p代表prediction,g代表ground truth。

为了有效利用这两种损失的优良特性,我们将the Dice loss with the cross entropy合并如下:

image-20220416094415472

其中α是用于平衡交叉熵损失和骰子损失的权重。D在0和1之间。log(D)将值范围从0扩展到负无穷大。当预测值与地面真值相差很大时,D很小,log(D)接近负无穷大。损失将大大增加,以惩罚这种糟糕的预测。该方法不仅可以利用骰子损失的特点,而且可以提高损失的灵敏度。

这种损失被称为Cross Entropy Log Dice(CEL-Dice)。它结合了交叉熵的稳定性Dice损失不受类别不平衡影响的特性。因此,它比交叉熵更好地解决了类不平衡问题,其稳定性比骰子损失更好。

3、Experiments

3.1 Datasets

我们构建了一个新的数据集Cata7来评估我们的网络,这是第一个用于语义分割的白内障手术器械数据集。该数据集由七个视频组成,每个视频记录一次完整的白内障手术。所有视频均来自北京同仁医院。每个视频被分割成一系列图像,分辨率为1920×1080像素。为了减少冗余,视频的采样频率从30 fps降至1 fps。此外,没有手术器械的图像也会被手动移除。每张图像都标有精确的边缘和手术器械的类型。

该数据集包含2500幅图像,分为训练集和测试集。训练集由五个视频序列组成,测试集由两个视频序列组成。表1显示了各类手术器械的数量。手术中使用了十种手术器械,如图3所示。

image-20220416094652833

image-20220416095137526

3.2 Training

在ImageNet上预先训练的ResNet34被用作编码器。预训练可以加速网络融合,提高网络性能[10]。由于计算资源有限,用于训练的每个图像的大小调整为960×544像素。该网络通过使用batch size为8的Adam进行训练。在训练期间动态调整学习率,以防止过度适应。初始学习率为4×10−5.每30次迭代,学习率乘以0.8。至于CEL-Dice中的α,经过几次实验后,它被设置为0.2。选择Dice coefficient and Intersection-Over-Union(IOU) 作为评估指标。

执行数据扩充以防止过度拟合。通过随机旋转、移位和翻转生成增强样本。通过数据增强获得800幅图像,增加特征多样性,有效防止过度拟合。批量标准化用于正则化。在解码器中,在每次卷积之后执行批标准化

3.3 Results

Ablation for augmented attention module

增强注意模块(AAM)旨在聚合多层次功能。它捕获全局上下文和语义依赖,以强调关键特征并抑制背景特征。为了验证其性能,我们进行了一系列实验。结果如表2所示。

image-20220416095617465

基本网络采用无AAM的RAUNet,平均骰子数达到95.12%,平均IOU达到91.31%。采用AAM的基础网络平均骰子数达到97.71%,平均IOU达到95.62%。通过应用AAM,平均骰子增加2.59%,平均IOU增加4.31%。此外,AAM还与GAU进行了比较[11]。使用GAU的基本网络实现了96.61%的平均骰子和93.76%的平均IOU。与使用AAM的基本网络相比,其平均Dice和平均IOU分别减少了1.10%和1.86%。此外,通过应用AAM,参数仅增加0.26M,占基础网络的1.19%。通过应用GAU,参数增加了0.60M,是AAM增加的参数量的2.31倍。这些结果表明,AAM不仅可以显著提高分割精度,而且不会增加太多的参数

为了进行直观的比较,图4(a)显示了基本网络和RAUNet的分割结果。红线表示对比区域。可以发现,基本网络的结果中存在分类错误。此外,在第三张图像中,手术器械没有完全分割。同时,RAUNet可以通过应用AAM精确分割手术器械。由RAUNet得到的mask和ground truth是一样的。这表明AAM有助于捕获高级语义特征并改进特征表示

image-20220416095729567

Comparison with state-of-the-art

为了进一步验证RAUNet的性能,将其与U-Net[8]、TernausNet[10]和LinkNet[12]进行了比较。如表3所示,RAUNet实现了最先进的性能,平均骰子为97.71%,平均IOU为95.62%,优于其他方法。U-Net[8]实现了94.99%的平均骰子和91.11%的平均IOU。TernausNet[10]和LinkNet[12]分别实现了92.98%和92.21%的平均IOU。这些方法的性能比我们的RAUNet差得多。

image-20220416100021172

图5显示了通过各种方法获得的像素精度。可以发现,主切口刀经常被U-Net、TernausNet和LinkNet错误分类。由于手术中使用的主切刀时间短,样本少,导致网络拟合不足。此外,镜头钩经常被U-Net错误分类。这是因为镜头挂钩非常薄,导致严重的等级不平衡。此外,它与其他手术器械类似。U-Net无法捕获高级语义信息,这会导致错误分类。 尽管存在这些困难,我们的方法仍然达到了高像素精度。镜头钩和主切刀的像素精度分别为90.23%和100%。这些结果表明,RAUNet可以捕捉有区别的语义特征,并解决类别不平衡问题。

image-20220416100148902

为了给出直观的结果,上述方法的分割结果如图4(b)所示。RAUNet的分割结果与ground truth值一致,明显优于其他方法。此外,RAUNet的更多结果如图6所示。

image-20220416095941715

image-20220416100220067

Verify the performance of CEL-Dice

CEL Dice用于解决class平衡问题。它结合了交叉熵的稳定性和骰子损失不受类别不平衡影响的特性。为了验证其性能,将其与交叉熵和骰子损失进行了比较。网络在测试集上获得的平均骰子和平均IOU如图7所示。它们随着训练epoch的变化而变化。研究发现,CEL-Dice能显著提高分割精度,优于Dice损耗和交叉熵。

4、Conclusion

提出了一种新的手术器械语义分割网络RAUNet。增强注意力模块旨在强调关键区域。实验结果表明,增强注意模块在增加很少参数的情况下,可以显著提高分割精度。此外,还引入了一种称为交叉熵对数骰子的混合损失,有助于解决类不平衡问题。实验证明,RAUNet在Cata7数据集上实现了最先进的性能。

注:

文中提到的Baseline模型可以作为参考,

十号文献:

Iglovikov, V., Shvets, A.: Ternausnet: U-Net with vgg11 encoder pre-trained on imagenet for image segmentation. arXiv preprint arXiv:1801.05746 (2018)

十一号文献:

Li, H., Xiong, P., An, J., Wang, L.: Pyramid attention network for semantic segmentation. arXiv preprint arXiv:1805.10180 (2018)

十二号文献:

Chaurasia, A., Culurciello, E.: Linknet: Exploiting encoder representations for efficient semantic segmentation. In: 2017 IEEE Visual Communications and Image Processing (VCIP). pp. 1–4. IEEE (2017)

  • 3
    点赞
  • 12
    收藏
    觉得还不错? 一键收藏
  • 4
    评论
Deep Residual U-Net是一种基于U-Net和ResNet的图像分割网络。它采用了U-Net的编码器-解码器结构,同时在每个解码器块中使用了ResNet的残差块(DR块)来提高特征提取能力。DR块通过引入SE(Squeeze-and-Excitation)机制来增强编码器的全局特征提取能力,同时使用1×1卷积来改变特征图的维度,以确保3×3卷积滤波器不受前一层的影响。此外,为了避免网络太深的影响,在两组Conv 1×1-Conv 3×3操作之间引入了一条捷径,允许网络跳过可能导致性能下降的层,并将原始特征转移到更深的层。Deep Residual U-Net在多个图像分割任务中都取得了优秀的性能。 以下是Deep Residual U-Net的编码器-解码器结构示意图: ```python import torch.nn as nn class DRBlock(nn.Module): def __init__(self, in_channels, out_channels): super(DRBlock, self).__init__() self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1) self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1) self.relu = nn.ReLU(inplace=True) self.se = nn.Sequential( nn.AdaptiveAvgPool2d(1), nn.Conv2d(out_channels, out_channels // 16, kernel_size=1), nn.ReLU(inplace=True), nn.Conv2d(out_channels // 16, out_channels, kernel_size=1), nn.Sigmoid() ) def forward(self, x): identity = x out = self.conv1(x) out = self.relu(out) out = self.conv2(out) out = self.relu(out) out = self.conv3(out) out = self.se(out) * out out += identity out = self.relu(out) return out class DRUNet(nn.Module): def __init__(self, in_channels, out_channels, init_features=32): super(DRUNet, self).__init__() features = init_features self.encoder1 = nn.Sequential( nn.Conv2d(in_channels, features, kernel_size=3, padding=1), nn.ReLU(inplace=True), nn.Conv2d(features, features, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.encoder2 = nn.Sequential( DRBlock(features, features * 2), nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 2, features * 2), nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.encoder3 = nn.Sequential( DRBlock(features * 2, features * 4), nn.Conv2d(features * 4, features * 4, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 4, features * 4), nn.Conv2d(features * 4, features * 4, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) self.encoder4 = nn.Sequential( DRBlock(features * 4, features * 8), nn.Conv2d(features * 8, features * 8, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 8, features * 8), nn.Conv2d(features * 8, features * 8, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2) self.bottleneck = nn.Sequential( DRBlock(features * 8, features * 16), nn.Conv2d(features * 16, features * 16, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 16, features * 16), nn.Conv2d(features * 16, features * 16, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2) self.decoder4 = nn.Sequential( DRBlock(features * 16, features * 8), nn.Conv2d(features * 8, features * 8, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 8, features * 8), nn.Conv2d(features * 8, features * 8, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2) self.decoder3 = nn.Sequential( DRBlock(features * 8, features * 4), nn.Conv2d(features * 4, features * 4, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 4, features * 4), nn.Conv2d(features * 4, features * 4, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2) self.decoder2 = nn.Sequential( DRBlock(features * 4, features * 2), nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features * 2, features * 2), nn.Conv2d(features * 2, features * 2, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2) self.decoder1 = nn.Sequential( DRBlock(features * 2, features), nn.Conv2d(features, features, kernel_size=3, padding=1), nn.ReLU(inplace=True), DRBlock(features, features), nn.Conv2d(features, features, kernel_size=3, padding=1), nn.ReLU(inplace=True) ) self.conv = nn.Conv2d(features, out_channels, kernel_size=1) def forward(self, x): enc1 = self.encoder1(x) enc2 = self.encoder2(self.pool1(enc1)) enc3 = self.encoder3(self.pool2(enc2)) enc4 = self.encoder4(self.pool3(enc3)) bottleneck = self.bottleneck(self.pool4(enc4)) dec4 = self.upconv4(bottleneck) dec4 = torch.cat((enc4, dec4), dim=1) dec4 = self.decoder4(dec4) dec3 = self.upconv3(dec4) dec3 = torch.cat((enc3, dec3), dim=1) dec3 = self.decoder3(dec3) dec2 = self.upconv2(dec3) dec2 = torch.cat((enc2, dec2), dim=1) dec2 = self.decoder2(dec2) dec1 = self.upconv1(dec2) dec1 = torch.cat((enc1, dec1), dim=1) dec1 = self.decoder1(dec1) return self.conv(dec1) ```

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 4
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值