Daily Paper Reading 1 — LSKANet: Long Strip Kernel Attention Network for Robotic Surgical Scene Segmentation

[Paper] [Code] [JHBI 24]


Abstract

Surgical scene segmentation is a critical task in robot-assisted surgery. However, the complexity of surgical scenes, which mainly involves local feature similarity (e.g., between different anatomical tissues), complex intraoperative artifacts, and indistinguishable boundaries, poses significant challenges to accurate segmentation. To address these problems, the authors propose the Long Strip Kernel Attention Network (LSKANet), which includes two well-designed modules, the Dual-block Large Kernel Attention module (DLKA) and the Multiscale Affinity Feature Fusion module (MAFF), to achieve precise segmentation of surgical images. Specifically, by introducing strip convolutions with different topologies (cascaded and parallel) in the two blocks together with a large-kernel design, DLKA can make full use of region-like and strip-like surgical features and extract both visual and structural information, reducing the mis-segmentation caused by local feature similarity. In MAFF, affinity matrices computed from multiscale feature maps are applied as feature-fusion weights, which helps counter the interference of artifacts by suppressing activations in irrelevant regions. In addition, a hybrid loss with a Boundary Guided Head (BGH) is proposed to help the network segment indistinguishable boundaries effectively. The proposed LSKANet is evaluated on three datasets with different surgical scenes. Experimental results show that the method achieves new state-of-the-art results on all three datasets, with mIoU improvements of 2.6%, 1.4%, and 3.4%, respectively. Moreover, the method is compatible with different backbones and can significantly improve their segmentation accuracy.


I. Overview

(Figure: overall architecture of LSKANet; image not included.)


II. Dual-block Large Kernel Attention

DLKA consists of two sub-blocks, the Anatomy Block and the Instrument Block. Both use strip convolutions to realize an efficient large-kernel design; they differ in topology: the Anatomy Block applies its strip convolutions serially (1×W followed by H×1), while the Instrument Block applies them in parallel (1×W and H×1 on the same input, then summed).

Q&A
Q1: What are the pros and cons of large convolution kernels?
A1: Pros: a larger receptive field that captures long-range dependencies. Cons: they easily pick up noise from irrelevant regions.
Q2: Why is the Anatomy Block serial while the Instrument Block is parallel?
A2: The serial topology models rectangular regions, whereas the parallel topology still models strip-shaped regions; this matches the actual shapes of endoscopic tissue and surgical instruments, respectively (see the receptive-field sketch below).
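
To make A2 concrete, here is a small sketch (mine, not from the paper) that measures the effective receptive field of a single output pixel by back-propagating from it: the cascaded pair covers a full 11×11 rectangle, while the parallel sum covers only a 21-pixel cross.

import torch
import torch.nn as nn

h = w = 31
conv_1xk = nn.Conv2d(1, 1, (1, 11), padding=(0, 5), bias=False)
conv_kx1 = nn.Conv2d(1, 1, (11, 1), padding=(5, 0), bias=False)
nn.init.constant_(conv_1xk.weight, 1.0)
nn.init.constant_(conv_kx1.weight, 1.0)

# Serial (AnaBlock-style): 1x11 followed by 11x1 -> rectangular footprint.
x = torch.ones(1, 1, h, w, requires_grad=True)
conv_kx1(conv_1xk(x))[0, 0, h // 2, w // 2].backward()
print(int((x.grad != 0).sum()))  # 121 = 11 * 11

# Parallel (InsBlock-style): 1x11 + 11x1 -> cross-shaped (strip) footprint.
x = torch.ones(1, 1, h, w, requires_grad=True)
(conv_1xk(x) + conv_kx1(x))[0, 0, h // 2, w // 2].backward()
print(int((x.grad != 0).sum()))  # 21 = 11 + 11 - 1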

Ablation Study
Replacing the DLKA module with standard convolution reduces the full model's parameter count from 72.26M to 72.25M, and mIoU drops from 69.6 to 68.8. Note that the paper does not specify the kernel size of this "standard convolution".
Removing the entire DLKA module drops mIoU from 69.6 to 68.0.
Removing the Anatomy Block from DLKA drops mIoU from 69.6 to 69.5. The drop is only 0.1 because, while the Anatomy classes degrade, the Instrument classes actually improve.
Removing the Instrument Block from DLKA drops mIoU from 69.6 to 69.0.

Code-AnaBlock

import torch
import torch.nn as nn
from thop import profile

class AnaBlock(nn.Module):
    """Anatomy Block: serial (cascaded) depthwise strip convolutions."""
    def __init__(self, dim):
        super().__init__()
        # 5x5 depthwise conv for local context
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # three cascaded strip-conv pairs (1xk then kx1) with k = 11, 21, 31
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3_1 = nn.Conv2d(dim, dim, (1, 31), padding=(0, 15), groups=dim)
        self.conv3_2 = nn.Conv2d(dim, dim, (31, 1), padding=(15, 0), groups=dim)
        # 1x1 conv to mix channels across the summed branches
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        # each pair is applied in series: 1xk then kx1 -> kxk rectangular field
        attn_1 = self.conv1_1(attn)
        attn_1 = self.conv1_2(attn_1)
        attn_2 = self.conv2_1(attn)
        attn_2 = self.conv2_2(attn_2)
        attn_3 = self.conv3_1(attn)
        attn_3 = self.conv3_2(attn_3)
        attn = attn + attn_1 + attn_2 + attn_3
        attn = self.conv3(attn)
        return attn * u  # use the result as an attention map gating the input
    
if __name__ == '__main__':
    net = AnaBlock(dim=512)
    x = torch.randn([1, 512, 22, 22])
    out = net(x)  
    flops, params = profile(net, inputs=(x, ))
    print(out.shape)  # 1, 512, 22, 22
    print('FLOPs = ' + str(round(flops/1000**3, 2)) + 'G')  # 0.16G
    print('Params = ' + str(round(params/1000**2, 2)) + 'M')  # 0.34M

Code-InsBlock

import torch
import torch.nn as nn
from thop import profile

class InsBlock(nn.Module):
    """Instrument Block: parallel depthwise strip convolutions."""
    def __init__(self, dim):
        super().__init__()
        # 5x5 depthwise conv for local context
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        # three strip-conv pairs (1xk and kx1) with k = 11, 21, 31
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3_1 = nn.Conv2d(dim, dim, (1, 31), padding=(0, 15), groups=dim)
        self.conv3_2 = nn.Conv2d(dim, dim, (31, 1), padding=(15, 0), groups=dim)
        # 1x1 conv to mix channels across the summed branches
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)
        # each pair is applied in parallel and summed -> cross-shaped strip field
        attn_1 = self.conv1_1(attn) + self.conv1_2(attn)
        attn_2 = self.conv2_1(attn) + self.conv2_2(attn)
        attn_3 = self.conv3_1(attn) + self.conv3_2(attn)
        attn = attn + attn_1 + attn_2 + attn_3
        attn = self.conv3(attn)
        return attn * u  # use the result as an attention map gating the input
    
if __name__ == '__main__':
    net = InsBlock(dim=512)
    x = torch.randn([1, 512, 22, 22])
    out = net(x)  
    flops, params = profile(net, inputs=(x, ))
    print(out.shape)  # 1, 512, 22, 22
    print('FLOPs = ' + str(round(flops/1000**3, 2)) + 'G')  # 0.16G
    print('Params = ' + str(round(params/1000**2, 2)) + 'M')  # 0.34M

Code-DLKA

import torch
import torch.nn as nn
from thop import profile
    
# Note: assumes the AnaBlock and InsBlock classes from the snippets above are in scope.
class DLKA(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj_1 = nn.Conv2d(dim, dim, 1)
        self.activation = nn.ReLU()
        self.ins_block = InsBlock(dim)
        self.ana_block = AnaBlock(dim)
        # a single 1x1 projection shared by both branches, as written here
        # (consistent with the 1.21M parameter count printed below)
        self.proj_2 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x_ins = self.proj_2(self.ins_block(x))  # instrument branch (parallel strips)
        x_ana = self.proj_2(self.ana_block(x))  # anatomy branch (serial strips)
        x = shortcut + x_ins + x_ana  # residual fusion of the two branches
        return x
    
if __name__ == '__main__':
    net = DLKA(dim=512)
    x = torch.randn([1, 512, 22, 22])
    out = net(x)  
    flops, params = profile(net, inputs=(x, ))
    print(out.shape)  # 1, 512, 22, 22
    print('FLOPs = ' + str(round(flops/1000**3, 2)) + 'G')  # 0.71G
    print('Params = ' + str(round(params/1000**2, 2)) + 'M')  # 1.21M

III. Multiscale Affinity Feature Fusion

MAFF serves as the decoder of LSKANet; its main job is better multi-level feature fusion. Specifically, shallow features contain more artifact noise but richer local detail, while deep features are less noisy but lack detail. Naively upsampling the deep features and fusing them with the shallow ones raises a feature-alignment problem. The paper instead introduces Feature Affinity to measure how similar the contents of different feature maps are at the same spatial location. Low affinity means the region is more likely to contain noise and should be suppressed; high affinity means it should be enhanced. Viewed this way, Feature Affinity is essentially a form of spatial attention.

Ablation Study
Removing the Feature Affinity design from MAFF leaves the parameter count almost unchanged, and mIoU drops from 69.1 to 68.6.

Code-Affinity

import torch
import torch.nn as nn
from thop import profile

class ConvBN(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x


class Affinity(nn.Module):
    """Single-channel spatial affinity map in (0, 1) for two channel-aligned
    feature maps at the same resolution."""
    def __init__(self, dim):
        super().__init__()
        self.Sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)
        self.conv = ConvBN(dim, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x1, x2):
        aff_ = self.relu(x1 + x2)  # combine the two feature maps
        aff = self.conv(aff_)      # collapse channels to one affinity score per pixel
        aff = self.Sigmoid(aff)    # normalize to (0, 1)
        return aff

if __name__ == '__main__':
    net = Affinity(dim=512)
    x1 = torch.randn([1, 512, 22, 22])
    x2 = torch.randn([1, 512, 22, 22])
    out = net(x1, x2)  
    flops, params = profile(net, inputs=(x1, x2,))
    print(out.shape)  # 1, 1, 22, 22
    print('FLOPs = ' + str(round(flops/1000**3, 2)) + 'G')  # ~0.00G
    print('Params = ' + str(round(params/1000**2, 2)) + 'M')  # ~0.00M

Code-MAFF

import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile

class ConvBN(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        return x
    
class ConvBNReLU(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        self.conv = nn.Conv2d(in_planes, out_planes,
                              kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, bias=False)
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

# Note: assumes the Affinity class from the previous snippet is in scope.
class MAFF(nn.Module):
    def __init__(self):
        super().__init__()
        # 1x1 convs reducing the stage-2/3/4 features to 128 channels
        self.conv2 = ConvBN(256, 128, 1)
        self.conv3 = ConvBN(512, 128, 1)
        self.conv4 = ConvBN(1024, 128, 1)
        self.aff2 = Affinity(128)
        self.aff3 = Affinity(128)
        self.aff4 = Affinity(128)
        self.fusion_conv = ConvBNReLU(128 * 4, 256, 1)
        self.cls_seg = nn.Conv2d(256, 1, 1)  # single-channel head for this toy demo

    def forward(self, x1, x2, x3, x4):
        # upsample the deeper features to the resolution of x1
        x2 = F.interpolate(x2, size=x1.shape[2:], mode='bilinear', align_corners=False)
        x3 = F.interpolate(x3, size=x1.shape[2:], mode='bilinear', align_corners=False)
        x4 = F.interpolate(x4, size=x1.shape[2:], mode='bilinear', align_corners=False)
        # reduce channels, then gate each deep feature by its affinity with x1
        x2 = self.conv2(x2)
        x2 = x2 * self.aff2(x1, x2)
        x3 = self.conv3(x3)
        x3 = x3 * self.aff3(x1, x3)
        x4 = self.conv4(x4)
        x4 = x4 * self.aff4(x1, x4)
        # concatenate all scales, fuse, and predict
        out = self.fusion_conv(torch.cat([x1, x2, x3, x4], dim=1))
        out = self.cls_seg(out)
        return out
    
if __name__ == '__main__':
    x1 = torch.zeros([1, 128, 88, 88])
    x2 = torch.zeros([1, 256, 44, 44])
    x3 = torch.zeros([1, 512, 22, 22])
    x4 = torch.zeros([1, 1024, 11, 11])
    net = MAFF()
    out = net(x1, x2, x3, x4)  
    flops, params = profile(net, inputs=(x1, x2, x3, x4,))
    print(out.shape)  # 1, 1, 88, 88
    print('FLOPs = ' + str(round(flops/1000**3, 2)) + 'G')  # 2.82G
    print('Params = ' + str(round(params/1000**2, 2)) + 'M')  # 0.36M
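
Finally, a minimal sketch of how the snippets above could be wired together. This is my assumption, not the authors' code: the paper's overview figure (not reproduced here) defines where DLKA actually sits in the network; below it simply refines the deepest backbone feature before MAFF fuses the four scales, with toy shapes matching the earlier demos.

import torch

# Hypothetical wiring (assumption): DLKA on the deepest feature, then MAFF.
# Requires the DLKA and MAFF classes from the snippets above to be in scope.
net_dlka = DLKA(dim=1024)
net_maff = MAFF()

x1 = torch.randn(1, 128, 88, 88)    # shallowest backbone feature
x2 = torch.randn(1, 256, 44, 44)
x3 = torch.randn(1, 512, 22, 22)
x4 = torch.randn(1, 1024, 11, 11)   # deepest backbone feature

out = net_maff(x1, x2, x3, net_dlka(x4))
print(out.shape)  # torch.Size([1, 1, 88, 88])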