[论文阅读] Adaptive Context Selection for Polyp Segmentation

论文地址:https://doi.org/10.1007/978-3-030-59725-2_25
代码:https://github.com/ReaFly/ACSNet
发表于:MICCAI’20

Abstract

准确的息肉分割对结直肠癌的诊断和治疗具有重要意义。然而,由于息肉的形状和大小各异,它一直是非常具有挑战性的。近年来,在深度卷积神经网络的帮助下,最先进的方法在这项任务中取得了重大突破。然而,很少有算法明确考虑到息肉的大小和形状以及复杂的空间环境对分割性能的影响,这导致算法对复杂样本仍然无能为力。事实上,不同大小的息肉的分割依赖于不同的局部和全局背景信息进行区域对比推理。为了解决这些问题,我们提出了一个基于自适应上下文选择的编码器-解码器框架,它由局部上下文注意(LCA)模块、全局上下文模块(GCM)和自适应选择模块(ASM)组成。具体来说,LCA模块将本地语境特征从编码器层传递到解码器层,加强对由上一层预测图决定的困难区域的关注。GCM的目的是进一步探索全局的上下文特征,并将其发送到解码层。ASM用于自适应选择,并通过通道式关注聚合上下文特征。我们提出的方法在EndoScene和Kvasir-SEG数据集上进行了评估,与其他最先进的方法相比,显示了出色的性能。

I. Architecture

在这里插入图片描述
整体走的还是魔改UNet那一套(即图中的E-Block1至D-Block1),只不过在中间加了一些上文提出的三大结构。主要的点:

  • 前四个E-Block的输出扔到LCA去提取一些局部上下文信息,提升细节处效果,同时进行“skip connection”(注意这么说并不严谨,只是从设计上类似)
  • 最后一个E-Block的输出扔到GCM里(一般这种全局信息提取模块都是放在Encoder和Decoder的连接处),提取些全局信息
  • 除最后一个D-Block,各D-Block对non local、LCA、GCM的信息进行学习选择
  • 各decoder层的输出都被相应下采样的ground truth进行监督
II. LCA(Local Context Attention)

在这里插入图片描述
所谓“局部”即对更容易出错的复杂区域(如边缘)进行学习。对于某个中间Encoder层的输出Prediction,将其放入一个公式计算:
A t t i j = 1 − ∣ p i + 1 j − T ∣ max ⁡ ( T , 1 − T ) A t t_{i}^{j}=1-\frac{\left|p_{i+1}^{j}-T\right|}{\max (T, 1-T)} Attij=1max(T,1T) pi+1jT
A t t A t t Att表示注意力图, p p p表示上层Encoder的输出, T T T代表用来区分某个特定像素是属于前景还是背景的阈值,文中使用0.5。

这个公式说成人话就是,如果某个像素的分类十分确定(比如就是1,前景;就是0,背景),那么我们就没有必要对其进行细化,相应的 A t t A t t Att值为0;而某个像素的分类很模糊(比如就是0.5,不确定是前景还是背景),那么相应的 A t t A t t Att值为1,表示网络需要重点关注这一部分。

而学习的方式就是将这个注意力图与原encoder输出进行相加,传到下一层继续学习。

代码长这样:

class LCA(nn.Module):
    def __init__(self):
        super(LCA, self).__init__()
    def forward(self, x, pred):
        residual = x
        score = torch.sigmoid(pred)
        dist = torch.abs(score - 0.5)
        att = 1 - (dist / 0.5)
        att_x = x * att
        out = att_x + residual
        return out
III. GCM(Global Context Module)

在这里插入图片描述
所谓“全局”就是提取高层次的特征了,或者是说“综合不同尺寸的特征”。

但是真要论提取全局特征的话,现在已经有一些比较成熟的结构了,比如ASPP,ASPP就是拿几个不同尺寸的卷积核取不同尺度下的特征然后合起来。这个也差不多,比ASPP还简单些,思想类似。

代码长这样:

class GCM(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCM, self).__init__()
        pool_size = [1, 3, 5]
        out_channel_list = [256, 128, 64, 64]
        upsampe_scale = [2, 4, 8, 16]
        GClist = []
        GCoutlist = []
        for ps in pool_size:
            GClist.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(ps),
                nn.Conv2d(in_channels, out_channels, 1, 1),
                nn.ReLU(inplace=True)))
        GClist.append(nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, 1),
            nn.ReLU(inplace=True),
            NonLocalBlock(out_channels)))
        self.GCmodule = nn.ModuleList(GClist)
        for i in range(4):
            GCoutlist.append(nn.Sequential(nn.Conv2d(out_channels * 4, out_channel_list[i], 3, 1, 1),
                                           nn.ReLU(inplace=True),
                                           nn.Upsample(scale_factor=upsampe_scale[i], mode='bilinear')))
        self.GCoutmodel = nn.ModuleList(GCoutlist)

    def forward(self, x):
        xsize = x.size()[2:]
        global_context = []
        for i in range(len(self.GCmodule) - 1):
            global_context.append(F.interpolate(self.GCmodule[i](x), xsize, mode='bilinear', align_corners=True))
        global_context.append(self.GCmodule[-1](x))
        global_context = torch.cat(global_context, dim=1)
        output = []
        for i in range(len(self.GCoutmodel)):
            output.append(self.GCoutmodel[i](global_context))
        return output
IV. Adaptive Selection Module(ASP)

在这里插入图片描述
那么既然叫自适应选择,就来看看自适应了什么,又选择了什么。

所谓选择,就是我们拥有了不同层次的信息(non local,局部上下文,全局上下文),但是对于不同的case,这三者的重要性不一定相同,有些case可能需要更多的关注局部上下文,而有的case需要更多的关注全局上下文。

具体到实现,首先将这三个层次的信息按通道连接,然后走个Squeeze and Excite结构。这里的Non local和Squeeze and Excite也算是比较经典的结构了。

代码长这样:

class ASM(nn.Module):
    def __init__(self, in_channels, all_channels):
        super(ASM, self).__init__()
        self.non_local = NonLocalBlock(in_channels)
        self.selayer = SELayer(all_channels)
    def forward(self, lc, fuse, gc):
        fuse = self.non_local(fuse)
        fuse = torch.cat([lc, fuse, gc], dim=1)
        fuse = self.selayer(fuse)
        return fuse
  • 6
    点赞
  • 20
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值