1 概述
本文注重解决红外图像距离远、分辨率低的问题。具体做法是在设计特征提取网络时,对浅层的跨阶段部分连接(CSP)模块进行扩展和迭代,最大限度地利用浅层特征;在残差块中引入改进的注意力模块,实现对目标的聚焦和背景的抑制;另外还改进了网络检测头结构,增加了多尺度目标检测层,采用四层空间金字塔池增加接收野,提高了小目标的检测精度。
2 结构
下图为该论文提出的网络结构:
从结构来看,改进点有以下几个:
①在骨干网中,利用扩展的CSP模块对浅层和深层特征图进行丰富信息的提取,CSP模块中引入的注意力机制引导不同权值的分配,实现对弱特征和小特征的提取,扩展的CSP模块结构如下:
②SPP层通过四个池化窗口将通道维度上获得的结果进行拼接,解决锚点与特征图的对齐问题,随后使用SK注意模块来增强提取的特征,SK模块结构如下:
③在颈部网络中,采用PANet生成特征金字塔,采用自顶向下和自底向上的融合结构,有效融合从骨干网络中提取的多尺度特征,增强对不同尺度目标的检测。
本文解决红外图像问题的思路值得借鉴,采用的模块并不是很新颖,并且所改进的csp模块经过单步调试发现会有很多多余的结构,这部分结构输出的值都为0,模块的作用有待商榷。
3 复现
这篇论文没有开源代码,我将其中几个关键模块复现了一下:
3.1 Extended esp block:
class ExtendedCSP(nn.Module):
def __init__(self, in_channels, out_channels):
super(ExtendedCSP, self).__init__()
self.split = nn.Conv2d(in_channels, in_channels // 2, kernel_size=1)
self.conv1 = nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(in_channels // 2)
self.relu1 = nn.ReLU()
self.sk = SKconv(in_channels // 2, in_channels // 2)
self.conv2 = nn.Conv2d(in_channels // 2, in_channels // 2, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(in_channels // 2)
self.relu2 = nn.ReLU()
self.concat = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
x1 = self.split(x)
x2 = self.split(x)
x2 = self.relu1(self.bn1(self.conv1(x2)))
x2 = self.sk(x2)
x2 = self.relu2(self.bn2(self.conv2(x2)))
x = torch.cat([x1, x2], dim=1)
x = self.concat(x)
return x
3.2 SK_attention:
class SKConv(nn.Module):
def __init__(self, in_ch, out_ch, bias=False):
super(SKConv, self).__init__()
self.conv3x3 = nn.Conv2d(in_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=bias)
self.conv5x5 = nn.Conv2d(in_ch, out_ch, kernel_size=5, stride=1, padding=2, bias=bias)
self.global_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Linear(out_ch, out_ch // 16)
self.fc2 = nn.Linear(out_ch // 16, 2 * out_ch)
self.softmax = nn.Softmax(dim=1)
def forward(self, x):
u1 = self.conv3x3(x)
u2 = self.conv5x5(x)
u = u1 + u2
s = self.global_pool(u)
s = s.view(s.size(0), -1)
s = self.fc1(s)
s = F.relu(s)
s = self.fc2(s)
s = self.softmax(s)
s = s.view(s.size(0), 2, s.size(1) // 2, 1, 1)
u1 = u1 * s[:, 0]
u2 = u2 * s[:, 1]
u = u1 + u2
return u