YOLOX所使用的主干特征提取网络为CSPDarknet,如下图左侧框所示。
图片来源: Pytorch 搭建自己的YoloX目标检测平台(Bubbliiiing 深度学习 教程)_哔哩哔哩_bilibili
CSPDarknet的几个要点总结如下。
1. Focus网络结构
Focus结构的具体操作是,在一幅图像中行和列的方向进行隔像素抽取,组成新的特征层,每幅图像可重组为4个特征层,然后将4个特征层进行堆叠,将输入通道扩展为4倍。堆叠后的特征层相对于原先的3通道变为12通道,如下图所示:
PyTorch代码实现如下:
class Focus(nn.Module):
"""Focus width and height information into channel space."""
def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
super().__init__()
self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
def forward(self, x):
# shape of x (b,c,w,h) -> y(b,4c,w/2,h/2)
patch_top_left = x[..., ::2, ::2]
patch_top_right = x[..., ::2, 1::2]
patch_bot_left = x[..., 1::2, ::2]
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat(
(
patch_top_left,
patch_bot_left,
patch_top_right,
patch_bot_right,
),
dim=1,
)
return self.conv(x)
2. 残差网络Residual
CSPDarknet中的残差网络分为两个分支,主干分支做一次1x1卷积和一次3x3卷积,残差边部分不做任何处理,相当于直接将主干的输入和输出结合。
代码如下,
class Bottleneck(nn.Module):
# Standard bottleneck
def __init__(
self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
depthwise=False,
act="silu",
):
super().__init__()
hidden_channe