1.前言
最近在看医学图像处理相关的论文,本论文是师兄推荐的。在阅读该论文的时候,鉴于网上还没人发阅读笔记,因此,萌生了发布该篇博文的想法。
PS:本人水平有限,解释不周之处还望海涵~。
2.摘要
反馈注意力网络(FANet)是一种用于生物医学图像分割的新颖架构。它利用每个训练时期的预测图信息来修剪后续时期的预测图,并使用前一时期掩模与当前训练时期的特征图相结合的方式提供硬注意力。该网络还允许在测试时间迭代地纠正预测结果。实验表明,FANet在七个公开可用的生物医学成像数据集上表现出色,证明了其有效性。
3.本文主要贡献
- 1.反馈注意力学习:一种利用每个训练样本中存在的可变性的新机制,将掩模输出从一个epoch传播到下一个epoch,以抑制不需要的特征杂波。
- 2.预测掩码的迭代优化:使用反馈信息有助于在训练和推理中优化预测掩码。
- 3.嵌入式游程编码策略:每个样本的二进制掩码输出在传播到下一个epoch之前被有效压缩。
- 4.在七个医学数据集上的实验表明,FANet优于其他最先进(SOTA)算法。
- 5.FANet以更少的训练次数达到接近SOTA的性能。
4.网络结构
该篇论文的网络结构可以分为三个模块:SE残差块、MixPool块和完整的网络结构。
4.1 SE残差块
4.1.1 结构
该网络结构图如图所示,主要可以分为三个部分:
作者在SE残差块中使用两个3×3卷积和一个恒等映射,其中每个卷积层后面是一个批归一化层和一个ReLU激活函数。
作者在残差网络中添加了一个SE层,SE层充当内容感知机制,其相应地重新加权每个通道以创建鲁棒的表示。
4.1.2 代码
该篇论文的代码写的比较清晰,有很多可以学习的地方。
该层的代码如下(可对照上文的网络结构学习):
class ResidualBlock(nn.Module):
def __init__(self, in_c, out_c):
super(ResidualBlock, self).__init__()
self.conv1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(out_c)
self.conv2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(out_c)
self.conv3 = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0)
self.bn3 = nn.BatchNorm2d(out_c)
self.se = SELayer(out_c, out_c)
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
# 第1部分
x1 = self.conv1(x)
x1 = self.bn1(x1)
x1 = self.relu(x1)
# 第2部分
x2 = self.conv2(x1)
x2 = self.bn2(x2)
x2 = self.se(x2)
# 第3部分
x3 = self.conv3(x)
x3 = self.bn3(x3)
# 正常输出和跳跃连接相加
x4 = x2 + x3
x4 = self.relu(x4)
return x4
4.1.3 SE模块
该模块是SE残差块中的Squeeze&Excite块。
图源与推荐视频:SE 模块
具体步骤:1.压缩:通过一个全局pooling,把特征图变为1×1的大小。
2.激励:对1×1的特征图做交替使用一个全连接层和一个激活函数
3.进行一个Scale操作
该模块的代码如下:
class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
# 激励
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
# 步骤1
y = self.avg_pool(x).view(b, c)
# 步骤2
y = self.fc(y).view(b, c, 1, 1)
# 步骤3
return x * y.expand_as(x)
4.2 MixPool块
4.2.1 结构
该网络结构图如图所示,主要可以分为三个部分:
这个模块有助于网络在学习过程中记忆和应用历史时期的信息,以实现更有效的训练和预测能力。同时,还为SE残差块中学到的特征提供了硬注意力。
① MixPool块是接在SE残差块之后的,因此SE残差块传入的Fl特征通过3×3卷积,BN层和Relu激活函数,再进行一个1×1卷积和Sigmoid激活函数,最后使用一个阈值为0.5的激活函数来获得二进制掩码Ml’(硬注意力,大于0.5为1,小于等于0.5为0)。
② 接着对先前时期的输入掩码应用最大池化,再与二进制掩码进行联合操作。然后与原始特征图进行逐元素乘法运算,该运算抑制不相关的特征并增强重要的特征。
③ 最后将增强图与原始图分别进行3×3卷积、BN层和Relu,并将两个激活函数的输出连接起来。
4.2.2 代码
该层的代码如下:
class MixPool(nn.Module):
def __init__(self, in_c, out_c):
super(MixPool, self).__init__()
self.fmask = nn.Sequential(
nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c),
nn.ReLU(inplace=True),
nn.Conv2d(out_c, 1, kernel_size=1, padding=0),
nn.Sigmoid()
)
self.conv1 = nn.Sequential(
nn.Conv2d(in_c, out_c//2, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c//2),
nn.ReLU(inplace=True)
)
self.conv2 = nn.Sequential(
nn.Conv2d(in_c, out_c//2, kernel_size=3, padding=1),
nn.BatchNorm2d(out_c//2),
nn.ReLU(inplace=True)
)
def forward(self, x, m):
# 第1部分
fmask = (self.fmask(x) > 0.5).type(torch.cuda.FloatTensor)
# 第2部分
m = nn.MaxPool2d((m.shape[2]//x.shape[2], m.shape[3]//x.shape[3]))(m)
x1 = x * torch.logical_or(fmask, m).type(torch.cuda.FloatTensor)
# 第3部分
x1 = self.conv1(x1)
x2 = self.conv2(x)
x = torch.cat([x1, x2], axis=1)
return x
4.3 完整的网络结构
该网络结构是一个编码器-解码器结构,由四个编码器和四个解码器组成。作者实现了一个循环学习机制,网络结构如图所示:其中下面四个为编码器,上面四个为解码器。
可以看出该网络中两个SE残差块后立马接了一个MixPool块。
从Input Mask开始看,MixPool块使用先前的分割图,将RLE编码作为输入掩码(上一个epoch的结果)。其包含来自先前训练的信息并用于改进特征图的语义表示。
经过四个编码器后,进入解码器的部分。Transpose Convolution是转置卷积。Concatenate是连接,将编码器通过跳跃连接传过来的信息和解码器的信息相连。编码器获取输入图像,逐渐对其进行下采样,并将其编码为紧凑的表示。然后,解码器采用这种紧凑的表示,并尝试通过逐渐上采样并结合来自编码器的特征来重建语义表示。第二个SE残差块的输出充当对应解码器块的跳跃连接。
将来自最后一个解码器块的特征图和来自前一个epoch的分割掩码连接起来。最后,使用sigmoid激活函数进行1×1卷积。其输出用于使用组合的二进制交叉熵和dice损失来最小化训练损失,并生成分割掩码。
4.3.1 编码器代码实现
编码器结构如图所示(从左往右传播):
代码如下:
class EncoderBlock(nn.Module):
def __init__(self, in_c, out_c, name=None):
super(EncoderBlock, self).__init__()
self.name = name
self.r1 = ResidualBlock(in_c, out_c)
self.r2 = ResidualBlock(out_c, out_c)
self.p1 = MixPool(out_c, out_c)
self.pool = nn.MaxPool2d((2, 2))
def forward(self, inputs, masks):
x = self.r1(inputs)
x = self.r2(x)
p = self.p1(x, masks)
o = self.pool(p)
return o, x
代码结构整体比较清晰,其中,o表示输出结果,x表示跳跃连接的信息。
4.3.2 解码器代码实现
解码器结构如图所示(从右往左传播):
代码如下:
class DecoderBlock(nn.Module):
def __init__(self, in_c, out_c, name=None):
super(DecoderBlock, self).__init__()
self.upsample = nn.ConvTranspose2d(in_c, in_c, kernel_size=4, stride=2, padding=1)
self.r1 = ResidualBlock(in_c+in_c, out_c)
self.r2 = ResidualBlock(out_c, out_c)
self.p1 = MixPool(out_c, out_c)
def forward(self, inputs, skip, masks):
x = self.upsample(inputs)
x = torch.cat([x, skip], axis=1)
x = self.r1(x)
x = self.r2(x)
p = self.p1(x, masks)
return p
4.3.3 完整的网络代码
class FANet(nn.Module):
def __init__(self):
super(FANet, self).__init__()
self.e1 = EncoderBlock(3, 32)
self.e2 = EncoderBlock(32, 64)
self.e3 = EncoderBlock(64, 128)
self.e4 = EncoderBlock(128, 256)
self.d1 = DecoderBlock(256, 128)
self.d2 = DecoderBlock(128, 64)
self.d3 = DecoderBlock(64, 32)
self.d4 = DecoderBlock(32, 16)
self.output = nn.Conv2d(16+1, 1, kernel_size=1, padding=0)
def forward(self, x):
inputs, masks = x[0], x[1]
p1, s1 = self.e1(inputs, masks)
p2, s2 = self.e2(p1, masks)
p3, s3 = self.e3(p2, masks)
p4, s4 = self.e4(p3, masks)
d1 = self.d1(p4, s4, masks)
d2 = self.d2(d1, s3, masks)
d3 = self.d3(d2, s2, masks)
d4 = self.d4(d3, s1, masks)
d5 = torch.cat([d4, masks], axis=1)
output = self.output(d5)
return output
5.结果
FANet在以下七个数据集上都取得了优异的效果,具体可看原论文~