Abstract
手术场景分割是机器人辅助手术中的一项关键任务。然而,手术场景的复杂性,主要包括局部特征相似性(例如,在不同解剖组织之间)、术中复杂的伪影和难以区分的边界,对准确分割构成了重大挑战。为了解决这些问题,我们提出了长条核注意力网络(LSKANet),包括两个设计良好的模块,分别是双块大核注意力模块(DLKA)和多尺度亲和特征融合模块(MAFF),可以实现手术图像的精确分割。具体来说,通过在两个块中引入具有不同拓扑结构(级联和并行)的条形卷积和大内核设计,DLKA可以充分利用区域和条状手术特征,并提取视觉和结构信息,以减少由局部特征相似性引起的错误分割。在MAFF中,从多尺度特征图计算的亲和矩阵作为特征融合权重应用,这有助于通过抑制不相关区域的激活来解决伪影的干扰。此外,该文提出了带边界引导头(BGH)的混合损失,以帮助网络有效地分割难以区分的边界。我们在具有不同手术场景的三个数据集上评估了所提出的LSKANet。实验结果表明,我们的方法在所有三个数据集上都取得了当前最优(state-of-the-art)的结果,mIoU分别提高了2.6%、1.4%和3.4%。此外,我们的方法与不同的backbone兼容,可以显著提高它们的分割精度。
I. Overview
II. Dual-block Large Kernel Attention
DLKA由两个子Block构成,分别为Anatomy Block和Instrument Block。两种block的共有特点是都用条形卷积(Strip Convolution)来实现高效大卷积核设计,区别在于Anatomy Block的条形卷积采用串行设计(先1×W再H×1),而Instrument Block的条形卷积采用并行设计(同时1×W,H×1再相加)。
Q&A
Q1: 大卷积核优缺点。
A1: 优点,扩大感受野,捕捉长距离依赖;缺点,容易引入不相关区域中的噪声。
Q2: 为什么Anatomy Block是串行而Instrument Block是并行。
A2: 串行(级联)两次条形卷积建模的是矩形区域,并行建模的仍是条状区域,分别与内镜视野中解剖组织(块状)和手术器械(细长条状)的客观形状相符。
Abla Study
将DLKA模块替换为标准卷积后, 完整模型参数由72.26M降低至72.25M, mIoU由69.6降低至68.8。注意原文并未说明这里所述的"标准卷积"的尺寸。
移除整个DLKA模块,mIoU由69.6降低至68.0。
移除DLKA中的Anatomy Block,mIoU由69.6降低至69.5。这里mIoU只降了0.1是因为Anatomy类性能降低的同时Instrument类的性能有所上升。
移除DLKA中的Instrument Block,mIoU由69.6降低至69.0。
Code-AnaBlock
import torch
import torch.nn as nn
from thop import profile
class AnaBlock(nn.Module):
    """Anatomy Block: cascaded (serial) strip-convolution attention.

    Builds a large effective receptive field from a 5x5 depthwise conv
    followed by three serial strip-conv pairs (1xK then Kx1, for
    K = 11, 21, 31).  Because each pair is cascaded, it models a KxK
    *rectangular* region, matching the blob-like shape of anatomical
    tissue.  The aggregated map is projected by a 1x1 conv into an
    attention map and multiplied onto the input.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # All large kernels are depthwise (groups=dim) strips, keeping
        # parameters and FLOPs low despite the big receptive field.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3_1 = nn.Conv2d(dim, dim, (1, 31), padding=(0, 15), groups=dim)
        self.conv3_2 = nn.Conv2d(dim, dim, (31, 1), padding=(15, 0), groups=dim)
        # 1x1 channel mixer producing the final attention map.
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return the attention-reweighted feature map (same shape as x)."""
        attn = self.conv0(x)
        # Serial 1xW then Hx1 -> rectangular (W x H) receptive field.
        attn_1 = self.conv1_2(self.conv1_1(attn))
        attn_2 = self.conv2_2(self.conv2_1(attn))
        attn_3 = self.conv3_2(self.conv3_1(attn))
        attn = self.conv3(attn + attn_1 + attn_2 + attn_3)
        # Fix: the original cloned `x` before this multiply; nothing below
        # mutates `x` in place, so the full-tensor copy was unnecessary.
        return attn * x
if __name__ == '__main__':
    # Smoke test: run a dummy tensor through the block and report the
    # model complexity measured by thop.
    block = AnaBlock(dim=512)
    dummy = torch.randn([1, 512, 22, 22])
    result = block(dummy)
    flops, params = profile(block, inputs=(dummy, ))
    flops_g = round(flops / 1000 ** 3, 2)
    params_m = round(params / 1000 ** 2, 2)
    print(result.shape)  # expect torch.Size([1, 512, 22, 22])
    print(f'FLOPs = {flops_g}G')  # ~0.16G
    print(f'Params = {params_m}M')  # ~0.34M
Code-InsBlock
import torch
import torch.nn as nn
from thop import profile
class InsBlock(nn.Module):
    """Instrument Block: parallel strip-convolution attention.

    Uses the same depthwise kernels as AnaBlock (5x5 plus 1xK / Kx1
    strips, K = 11, 21, 31), but applies each strip pair in *parallel*
    and sums the results, so the receptive field stays strip-shaped —
    matching the elongated shape of surgical instruments.  The
    aggregated map is projected by a 1x1 conv into an attention map and
    multiplied onto the input.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        # Depthwise (groups=dim) strips keep the large kernels cheap.
        self.conv0 = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)
        self.conv1_1 = nn.Conv2d(dim, dim, (1, 11), padding=(0, 5), groups=dim)
        self.conv1_2 = nn.Conv2d(dim, dim, (11, 1), padding=(5, 0), groups=dim)
        self.conv2_1 = nn.Conv2d(dim, dim, (1, 21), padding=(0, 10), groups=dim)
        self.conv2_2 = nn.Conv2d(dim, dim, (21, 1), padding=(10, 0), groups=dim)
        self.conv3_1 = nn.Conv2d(dim, dim, (1, 31), padding=(0, 15), groups=dim)
        self.conv3_2 = nn.Conv2d(dim, dim, (31, 1), padding=(15, 0), groups=dim)
        # 1x1 channel mixer producing the final attention map.
        self.conv3 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return the attention-reweighted feature map (same shape as x)."""
        attn = self.conv0(x)
        # Parallel 1xW and Hx1 strips, summed -> strip-shaped receptive field.
        attn_1 = self.conv1_1(attn) + self.conv1_2(attn)
        attn_2 = self.conv2_1(attn) + self.conv2_2(attn)
        attn_3 = self.conv3_1(attn) + self.conv3_2(attn)
        attn = self.conv3(attn + attn_1 + attn_2 + attn_3)
        # Fix: the original cloned `x` before this multiply; nothing below
        # mutates `x` in place, so the full-tensor copy was unnecessary.
        return attn * x
if __name__ == '__main__':
    # Smoke test: run a dummy tensor through the block and report the
    # model complexity measured by thop.
    block = InsBlock(dim=512)
    dummy = torch.randn([1, 512, 22, 22])
    result = block(dummy)
    flops, params = profile(block, inputs=(dummy, ))
    flops_g = round(flops / 1000 ** 3, 2)
    params_m = round(params / 1000 ** 2, 2)
    print(result.shape)  # expect torch.Size([1, 512, 22, 22])
    print(f'FLOPs = {flops_g}G')  # ~0.16G
    print(f'Params = {params_m}M')  # ~0.34M
Code-DLKA
import torch
import torch.nn as nn
from thop import profile
class DLKA(nn.Module):
    """Dual-block Large Kernel Attention.

    Projects the input with a 1x1 conv + ReLU, runs it through an
    Instrument Block (parallel strips) and an Anatomy Block (serial
    strips) side by side, projects each branch with a 1x1 conv, and adds
    both results to the residual shortcut.

    Args:
        dim: number of input (and output) channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.proj_1 = nn.Conv2d(dim, dim, 1)
        self.activation = nn.ReLU()
        self.ins_block = InsBlock(dim)
        self.ana_block = AnaBlock(dim)
        # NOTE(review): proj_2 is applied to BOTH branch outputs in
        # forward(), i.e. the output projection weights are shared across
        # branches.  This mirrors the original code, but separate
        # projections would be the more common design — confirm against
        # the paper before changing.
        self.proj_2 = nn.Conv2d(dim, dim, 1)

    def forward(self, x):
        """Return the residual-augmented dual-branch attention output."""
        # Fix: original did `shorcut = x.clone()` (typo + needless copy);
        # `x` is only rebound below, never mutated in place.
        shortcut = x
        x = self.activation(self.proj_1(x))
        x_ins = self.proj_2(self.ins_block(x))
        x_ana = self.proj_2(self.ana_block(x))
        return shortcut + x_ins + x_ana
if __name__ == '__main__':
    # Smoke test: run a dummy tensor through the module and report the
    # model complexity measured by thop.
    module = DLKA(dim=512)
    dummy = torch.randn([1, 512, 22, 22])
    result = module(dummy)
    flops, params = profile(module, inputs=(dummy, ))
    flops_g = round(flops / 1000 ** 3, 2)
    params_m = round(params / 1000 ** 2, 2)
    print(result.shape)  # expect torch.Size([1, 512, 22, 22])
    print(f'FLOPs = {flops_g}G')  # ~0.71G
    print(f'Params = {params_m}M')  # ~1.21M
III. Multiscale Affinity Feature Fusion
MAFF作为LSKANet的解码器,主要任务是实现更好的多级特征融合。具体来说,浅层特征包含较多的伪影噪声,但拥有更多局部细节;深层特征噪声较少,但是缺失细节。直接将深层特征上采样与浅层特征融合会引发特征对齐问题。本文的做法是引入特征亲和度(Feature Affinity)来评估同一个空间位置下,不同特征图内容的相似程度。如果相似程度低,则说明该区域更高概率存在噪声,应该进行抑制;反之则应该进行增强。从这个角度看,本文的Feature Affinity实际就是一种空间注意力。
Abla Study
将MAFF中的Feature Affinity设计移除, 模型参数几乎不变, mIoU由69.1降低至68.6。
Code-Affinity
import torch
import torch.nn as nn
from thop import profile
class ConvBN(nn.Module):
    """A bias-free Conv2d immediately followed by BatchNorm2d."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        # bias=False: BatchNorm's learnable shift makes a conv bias redundant.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Apply convolution then batch normalization."""
        return self.bn(self.conv(x))
class Affinity(nn.Module):
    """Spatial affinity weight between two same-shaped feature maps.

    Sums the two inputs, applies ReLU, squeezes to one channel with a
    1x1 ConvBN, and gates with a sigmoid, yielding a single-channel map
    in (0, 1).  High values mark positions where the two feature maps
    agree; low values flag likely noise/artifact regions to suppress.
    """

    def __init__(self, dim):
        super().__init__()
        self.Sigmoid = nn.Sigmoid()
        self.relu = nn.ReLU(inplace=True)
        self.conv = ConvBN(dim, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x1, x2):
        """Return a (N, 1, H, W) affinity map for x1 and x2."""
        combined = self.relu(x1 + x2)
        return self.Sigmoid(self.conv(combined))
if __name__ == '__main__':
    # Smoke test: affinity between two random maps, plus a thop report.
    module = Affinity(dim=512)
    a = torch.randn([1, 512, 22, 22])
    b = torch.randn([1, 512, 22, 22])
    result = module(a, b)
    flops, params = profile(module, inputs=(a, b,))
    flops_g = round(flops / 1000 ** 3, 2)
    params_m = round(params / 1000 ** 2, 2)
    print(result.shape)  # expect torch.Size([1, 1, 22, 22])
    print(f'FLOPs = {flops_g}G')  # ~0.00G
    print(f'Params = {params_m}M')  # ~0.00M
Code-MAFF
import torch
import torch.nn as nn
import torch.nn.functional as F
from thop import profile
class ConvBN(nn.Module):
    """A bias-free Conv2d immediately followed by BatchNorm2d."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        # bias=False: BatchNorm's learnable shift makes a conv bias redundant.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)

    def forward(self, x):
        """Apply convolution then batch normalization."""
        return self.bn(self.conv(x))
class ConvBNReLU(nn.Module):
    """A bias-free Conv2d followed by BatchNorm2d and ReLU."""

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
        super().__init__()
        # bias=False: BatchNorm's learnable shift makes a conv bias redundant.
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(out_planes)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply convolution, batch normalization, then ReLU."""
        return self.relu(self.bn(self.conv(x)))
class MAFF(nn.Module):
    """Multiscale Affinity Feature Fusion decoder head.

    Upsamples the three deeper feature maps to the resolution of the
    shallowest one, reduces each to 128 channels, gates each with an
    affinity map computed against the shallow feature (suppressing
    positions where the maps disagree, i.e. likely artifacts),
    concatenates all four, and predicts a single-channel map.

    NOTE(review): the channel widths (128/256/512/1024) and the
    1-channel output are hard-coded to a specific backbone/head setup —
    confirm against the surrounding project before reuse.
    """

    def __init__(self):
        super().__init__()
        # Channel reducers for the three deeper levels.
        self.conv2 = ConvBN(256, 128, 1)
        self.conv3 = ConvBN(512, 128, 1)
        self.conv4 = ConvBN(1024, 128, 1)
        # Per-level affinity gates against the shallow feature.
        self.aff2 = Affinity(128)
        self.aff3 = Affinity(128)
        self.aff4 = Affinity(128)
        self.fusion_conv = ConvBNReLU(128 * 4, 256, 1)
        self.cls_seg = nn.Conv2d(256, 1, 1)

    def forward(self, x1, x2, x3, x4):
        """Fuse four pyramid levels (shallow x1 ... deep x4) into one map."""
        target_size = x1.shape[2:]

        def _align(feat, reduce, affinity):
            # Upsample to the shallow resolution, reduce channels, then
            # gate by affinity with the shallow feature.
            feat = F.interpolate(feat, size=target_size, mode='bilinear', align_corners=False)
            feat = reduce(feat)
            return feat * affinity(x1, feat)

        x2 = _align(x2, self.conv2, self.aff2)
        x3 = _align(x3, self.conv3, self.aff3)
        x4 = _align(x4, self.conv4, self.aff4)
        fused = self.fusion_conv(torch.cat([x1, x2, x3, x4], dim=1))
        return self.cls_seg(fused)
if __name__ == '__main__':
    # Smoke test with a four-level feature pyramid, plus a thop report.
    f1 = torch.zeros([1, 128, 88, 88])
    f2 = torch.zeros([1, 256, 44, 44])
    f3 = torch.zeros([1, 512, 22, 22])
    f4 = torch.zeros([1, 1024, 11, 11])
    head = MAFF()
    result = head(f1, f2, f3, f4)
    flops, params = profile(head, inputs=(f1, f2, f3, f4,))
    flops_g = round(flops / 1000 ** 3, 2)
    params_m = round(params / 1000 ** 2, 2)
    print(result.shape)  # expect torch.Size([1, 1, 88, 88])
    print(f'FLOPs = {flops_g}G')  # ~2.82G
    print(f'Params = {params_m}M')  # ~0.36M