代码来源
https://github.com/ReaFly/ACSNet
模块作用
LCA 模块将局部上下文特征从编码器层传递到解码器层,增强了对由前一层的预测图确定的硬区域的注意力。GCM 旨在进一步探索全局上下文特征并发送到解码器层。ASM 用于通过通道注意自适应地选择和聚合上下文特征。
模块结构
ACSNet基于UNet结构,包含编码器和解码器。编码器使用ResNet34提取特征,解码器生成分割图,每个解码块都受到下采样地面真相的监督。
- 功能:捕获细粒度的局部上下文信息,重点关注硬样本(难以分割的区域)。
- 实现:使用上一层解码器的预测图计算注意力图,公式为:
其中,P是预测图,T=0.5为阈值,H,W为注意力图的高度和宽度。特征通过注意力值加权后与原始特征相加,输出增强的局部上下文特征。
- 作用:替换UNet中的跳跃连接,提供不同感受野的局部上下文增强,实验显示在Kvasir-SEG上提高了Dice得分0.79%。
- 功能:捕获全局上下文信息,补偿解码器层级细化过程中可能丢失的全局特征。
- 结构:包含四个分支:
- 全局平均池化,捕获整体上下文。
- 两个自适应局部平均池化分支(3x3和5x5),捕获中尺度上下文。
- 身份映射分支,结合非局部操作,处理长距离依赖。
- 输出特征图包括1x1、3x3、5x5尺度和原始分辨率,特征上采样后拼接,密集馈送到每个解码器的ASM模块。
- 效果:在Kvasir-SEG上,加入GCM后Dice得分从89.00%提高到90.28%,提升1.28%。
- 功能:自适应选择和融合LCA、GCM以及上一解码块的特征,处理不同大小和形状的息肉。
- 实现:使用“挤压-激励”块进行通道级特征重新校准,通过全局平均池化将特征图压缩为向量,再通过全连接层学习通道权重,sigmoid操作限制权重在0-1范围内。非局部操作增强解码器特征的长距离依赖。
- 效果:在Kvasir-SEG上,加入ASM后Dice得分从90.28%提高到91.30%,提升1.02%。
代码
class ACSNet(nn.Module):
def __init__(self, num_classes):
super(ACSNet, self).__init__()
resnet = models.resnet34(pretrained=True)
# Encoder
self.encoder1_conv = resnet.conv1
self.encoder1_bn = resnet.bn1
self.encoder1_relu = resnet.relu
self.maxpool = resnet.maxpool
self.encoder2 = resnet.layer1
self.encoder3 = resnet.layer2
self.encoder4 = resnet.layer3
self.encoder5 = resnet.layer4
# Decoder
self.decoder5 = DecoderBlock(in_channels=512, out_channels=512)
self.decoder4 = DecoderBlock(in_channels=1024, out_channels=256)
self.decoder3 = DecoderBlock(in_channels=512, out_channels=128)
self.decoder2 = DecoderBlock(in_channels=256, out_channels=64)
self.decoder1 = DecoderBlock(in_channels=192, out_channels=64)
self.outconv = nn.Sequential(ConvBlock(64, 32, kernel_size=3, stride=1, padding=1),
nn.Dropout2d(0.1),
nn.Conv2d(32, num_classes, 1))
# Sideout
self.sideout2 = SideoutBlock(64, 1)
self.sideout3 = SideoutBlock(128, 1)
self.sideout4 = SideoutBlock(256, 1)
self.sideout5 = SideoutBlock(512, 1)
# local context attention module
self.lca1 = LCA()
self.lca2 = LCA()
self.lca3 = LCA()
self.lca4 = LCA()
# global context module
self.gcm = GCM(512, 64)
# adaptive selection module
self.asm4 = ASM(512, 1024)
self.asm3 = ASM(256, 512)
self.asm2 = ASM(128, 256)
self.asm1 = ASM(64, 192)
def forward(self, x):
# x 224
e1 = self.encoder1_conv(x) # 128
e1 = self.encoder1_bn(e1)
e1 = self.encoder1_relu(e1)
e1_pool = self.maxpool(e1) # 56
e2 = self.encoder2(e1_pool)
e3 = self.encoder3(e2) # 28
e4 = self.encoder4(e3) # 14
e5 = self.encoder5(e4) # 7
global_contexts = self.gcm(e5)
d5 = self.decoder5(e5) # 14
out5 = self.sideout5(d5)
lc4 = self.lca4(e4, out5)
gc4 = global_contexts[0]
comb4 = self.asm4(lc4, d5, gc4)
d4 = self.decoder4(comb4) # 28
out4 = self.sideout4(d4)
lc3 = self.lca3(e3, out4)
gc3 = global_contexts[1]
comb3 = self.asm3(lc3, d4, gc3)
d3 = self.decoder3(comb3) # 56
out3 = self.sideout3(d3)
lc2 = self.lca2(e2, out3)
gc2 = global_contexts[2]
comb2 = self.asm2(lc2, d3, gc2)
d2 = self.decoder2(comb2) # 128
out2 = self.sideout2(d2)
lc1 = self.lca1(e1, out2)
gc1 = global_contexts[3]
comb1 = self.asm1(lc1, d2, gc1)
d1 = self.decoder1(comb1) # 224*224*64
out1 = self.outconv(d1) # 224
return torch.sigmoid(out1), torch.sigmoid(out2), torch.sigmoid(out3), \
torch.sigmoid(out4), torch.sigmoid(out5)
总结
在本文中,我们认为高效地感知局部与全局上下文对于提高息肉区域定位与分割的性能至关重要。基于此,我们提出了一种基于自适应上下文选择的编码器-解码器框架,其中包含用于困难区域挖掘的局部上下文提取的 LCA 模块、用于每个解码器块中的全局特征表示和增强的 GCM 模块以及用于上下文信息聚合和选择的 ASM 组件。