这个网络结构是在YoloV5里面使用到比较有趣的网络结构,具体操作是在一张图片中每隔一个像素拿到一个值,这个时候获得了四个独立的特征层,然后将四个独立的特征层进行堆叠,此时宽高信息就集中到了通道信息,输入通道扩充了四倍。拼接起来的特征层相对于原先的三通道变成了十二个通道。
python代码实现
class Focus(nn.Module):
def __init__(self, in_channels, out_channels, ksize=1, stride=1, act="silu"):
super(Focus, self).__init__()
self.conv = BaseConv(in_channels * 4, out_channels, ksize, stride, act=act)
def forward(self, x):
# 图像被分成四块
# 行和列均按2叠加 对应图片中1
patch_top_left = x[..., ::2, ::2]
# 对应图片中3
patch_bot_left = x[..., 1::2, ::2]
# 对应图片中2
patch_top_right = x[..., ::2, 1::2]
# 对应图片中4
patch_bot_right = x[..., 1::2, 1::2]
x = torch.cat((patch_top_left, patch_bot_left, patch_top_right, patch_bot_right), dim=1)
return self.conv(x)