论文地址:https://arxiv.org/abs/2210.02093
源码地址:https://github.com/QY1994-0919/CFPNet (CFPNet: Centralized Feature Pyramid for Object Detection)
主要思想:
视觉特征金字塔在许多应用中已经显示出其有效性和高效性的优势。然而,现有的方法过于关注层间特征交互,却忽视了层内特征调节,而实际上层内特征调节是非常有益的。虽然有些方法试图通过注意力机制或视觉转换器来学习紧凑的层内特征表示,但它们忽略了对于密集预测任务非常重要的边缘区域。为了解决这个问题,本文提出了一种用于目标检测的集中式特征金字塔(CFP),它基于全局显式的集中式特征调节。具体来说,我们首先提出了一种空间显式的视觉中心方案,其中使用轻量级的多层感知机(MLP)来捕获全局长距离依赖关系,并使用并行的可学习视觉中心机制来捕获输入图像的局部边缘区域。在此基础上,我们提出了一种全局集中式的调节方法,用于从上到下对常用的特征金字塔进行调节,其中从最深层的层内特征中获得的显式视觉中心信息用于调节前面的浅层特征。与现有的特征金字塔相比,CFP不仅能够捕获全局长距离依赖关系,还能够有效地获得全面且具有判别性的特征表示。在具有挑战性的MS-COCO数据集上的实验结果表明,我们提出的CFP可以在最先进的YOLOv5和YOLOX目标检测基准上实现一致的性能提升。
简单来说,我们提出了一种新的目标检测方法,它使用了一种叫做集中式特征金字塔(CFP)的技术。CFP不仅关注不同层之间的特征交互,还注重每一层内部的特征调整。通过这种方式,CFP能够更全面地提取图像的特征,尤其是那些容易被忽略的边缘区域。实验证明,这种方法在目标检测任务中取得了很好的效果,相比之前的方法有了明显的性能提升。
模型架构:
Pytorch版源码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class DeConv(nn.Module):
    """Transposed 2-D convolution, optionally followed by BatchNorm + PReLU.

    Args:
        nIn: input channels.
        nOut: output channels.
        kSize / stride / padding / output_padding / dilation / groups / bias:
            forwarded to ``nn.ConvTranspose2d``.
        bn_acti: when True, append a ``BNPReLU`` block after the deconvolution.
    """

    def __init__(self, nIn, nOut, kSize, stride, padding, output_padding,
                 dilation=(1, 1), groups=1, bn_acti=False, bias=False):
        super().__init__()
        self.bn_acti = bn_acti
        self.conv = nn.ConvTranspose2d(
            nIn, nOut, kernel_size=kSize, stride=stride, padding=padding,
            output_padding=output_padding, dilation=dilation,
            groups=groups, bias=bias)
        if bn_acti:
            self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        out = self.conv(input)
        return self.bn_prelu(out) if self.bn_acti else out
class Conv(nn.Module):
    """2-D convolution, optionally followed by BatchNorm + PReLU.

    Args:
        nIn: input channels.
        nOut: output channels.
        kSize / stride / padding / dilation / groups / bias:
            forwarded to ``nn.Conv2d``.
        bn_acti: when True, append a ``BNPReLU`` block after the convolution.
    """

    def __init__(self, nIn, nOut, kSize, stride, padding,
                 dilation=(1, 1), groups=1, bn_acti=False, bias=False):
        super().__init__()
        self.bn_acti = bn_acti
        self.conv = nn.Conv2d(
            nIn, nOut, kernel_size=kSize, stride=stride, padding=padding,
            dilation=dilation, groups=groups, bias=bias)
        if bn_acti:
            self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        out = self.conv(input)
        return self.bn_prelu(out) if self.bn_acti else out
class BNPReLU(nn.Module):
    """BatchNorm2d followed by a channel-wise PReLU activation.

    Args:
        nIn: number of channels to normalize and activate.
    """

    def __init__(self, nIn):
        super().__init__()
        # eps=1e-3 follows the reference CFPNet implementation.
        self.bn = nn.BatchNorm2d(nIn, eps=1e-3)
        self.acti = nn.PReLU(nIn)

    def forward(self, input):
        return self.acti(self.bn(input))
# Down-sampling block.
class DownSamplingBlock(nn.Module):
    """Halve spatial resolution while mapping nIn to nOut channels.

    When channels grow (nIn < nOut), a strided 3x3 conv produces only
    nOut - nIn channels and is concatenated with a max-pooled copy of the
    input; otherwise the strided conv alone produces all nOut channels.

    Args:
        nIn: input channels.
        nOut: output channels.
    """

    def __init__(self, nIn, nOut):
        super().__init__()
        self.nIn = nIn
        self.nOut = nOut
        conv_channels = nOut - nIn if nIn < nOut else nOut
        self.conv3x3 = Conv(nIn, conv_channels, kSize=3, stride=2, padding=1)
        self.max_pool = nn.MaxPool2d(2, stride=2)
        self.bn_prelu = BNPReLU(nOut)

    def forward(self, input):
        out = self.conv3x3(input)
        if self.nIn < self.nOut:
            # Concatenate the conv features with a pooled copy of the input.
            out = torch.cat([out, self.max_pool(input)], 1)
        return self.bn_prelu(out)
# Repeated average pooling: each application halves the resolution (1/2, 1/4, 1/8, ...).
class InputInjection(nn.Module):
    """Down-sample the raw input `ratio` times with stacked 3x3 avg-pools (stride 2)."""

    def __init__(self, ratio):
        super().__init__()
        self.pool = nn.ModuleList(
            nn.AvgPool2d(3, stride=2, padding=1) for _ in range(ratio))

    def forward(self, input):
        out = input
        for pool in self.pool:
            out = pool(out)
        return out
class CFPModule(nn.Module):
    """CFP residual module: four parallel factorized-convolution branches.

    After a pre-activation and a channel reduction to nIn/4, four branches
    each apply three cascaded factorized (3x1 then 1x3) depthwise-style
    stages emitting nIn/16, nIn/16 and nIn/8 channels. The branches differ
    only in dilation rate: branch 1 uses dilation 1, branches 2/3/4 use
    d/4+1, d/2+1 and d+1 respectively, giving progressively larger
    receptive fields. Per-branch stage outputs are concatenated (nIn/4
    each), the four branches are hierarchically summed, concatenated back
    to nIn channels, fused by a 1x1 conv and added to the input.

    Args:
        nIn: input (= output) channel count; assumed divisible by 16 so the
            nIn // 16 group counts are valid — TODO confirm at call sites.
        d: base dilation rate controlling the three dilated branches.
        KSize: kernel size of the channel-reduction conv (3 with padding=1).
        dkSize: kernel size of the factorized convs.
    """

    def __init__(self, nIn, d=1, KSize=3, dkSize=3):
        super().__init__()
        # Pre-activation before the reduction conv and post-concat activation.
        self.bn_relu_1 = BNPReLU(nIn)
        self.bn_relu_2 = BNPReLU(nIn)
        # Channel reduction: nIn -> nIn/4 (KSize=3 conv, padding=1 keeps H, W).
        self.conv1x1_1 = Conv(nIn, nIn // 4, KSize, 1, padding=1, bn_acti=True)
        # Branch 4 — largest dilation (d + 1); padding matches dilation so
        # spatial size is preserved. Stages: nIn/4 -> nIn/16 -> nIn/16 -> nIn/8.
        self.dconv3x1_4_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                                 padding=(1 * d + 1, 0), dilation=(d + 1, 1), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_4_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, 1 * d + 1), dilation=(1, d + 1), groups=nIn // 16, bn_acti=True)
        self.dconv3x1_4_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                                 padding=(1 * d + 1, 0), dilation=(d + 1, 1), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_4_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, 1 * d + 1), dilation=(1, d + 1), groups=nIn // 16, bn_acti=True)
        self.dconv3x1_4_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                                 padding=(1 * d + 1, 0), dilation=(d + 1, 1), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_4_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                                 padding=(0, 1 * d + 1), dilation=(1, d + 1), groups=nIn // 8, bn_acti=True)
        # Branch 1 — no dilation (plain factorized convs, padding 1).
        self.dconv3x1_1_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                                 padding=(1, 0), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_1_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, 1), groups=nIn // 16, bn_acti=True)
        self.dconv3x1_1_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                                 padding=(1, 0), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_1_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, 1), groups=nIn // 16, bn_acti=True)
        self.dconv3x1_1_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                                 padding=(1, 0), groups=nIn // 16, bn_acti=True)
        self.dconv1x3_1_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                                 padding=(0, 1), groups=nIn // 8, bn_acti=True)
        # Branch 2 — small dilation int(d/4 + 1).
        self.dconv3x1_2_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                                 padding=(int(d / 4 + 1), 0), dilation=(int(d / 4 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_2_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, int(d / 4 + 1)), dilation=(1, int(d / 4 + 1)), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv3x1_2_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                                 padding=(int(d / 4 + 1), 0), dilation=(int(d / 4 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_2_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, int(d / 4 + 1)), dilation=(1, int(d / 4 + 1)), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv3x1_2_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                                 padding=(int(d / 4 + 1), 0), dilation=(int(d / 4 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_2_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                                 padding=(0, int(d / 4 + 1)), dilation=(1, int(d / 4 + 1)), groups=nIn // 8,
                                 bn_acti=True)
        # Branch 3 — medium dilation int(d/2 + 1).
        self.dconv3x1_3_1 = Conv(nIn // 4, nIn // 16, (dkSize, 1), 1,
                                 padding=(int(d / 2 + 1), 0), dilation=(int(d / 2 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_3_1 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, int(d / 2 + 1)), dilation=(1, int(d / 2 + 1)), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv3x1_3_2 = Conv(nIn // 16, nIn // 16, (dkSize, 1), 1,
                                 padding=(int(d / 2 + 1), 0), dilation=(int(d / 2 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_3_2 = Conv(nIn // 16, nIn // 16, (1, dkSize), 1,
                                 padding=(0, int(d / 2 + 1)), dilation=(1, int(d / 2 + 1)), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv3x1_3_3 = Conv(nIn // 16, nIn // 8, (dkSize, 1), 1,
                                 padding=(int(d / 2 + 1), 0), dilation=(int(d / 2 + 1), 1), groups=nIn // 16,
                                 bn_acti=True)
        self.dconv1x3_3_3 = Conv(nIn // 8, nIn // 8, (1, dkSize), 1,
                                 padding=(0, int(d / 2 + 1)), dilation=(1, int(d / 2 + 1)), groups=nIn // 8,
                                 bn_acti=True)
        # Final 1x1 fusion conv (no BN/PReLU; activation precedes it in forward).
        self.conv1x1 = Conv(nIn, nIn, 1, 1, padding=0, bn_acti=False)

    def forward(self, input):
        # Pre-activation then reduce channels to nIn/4.
        inp = self.bn_relu_1(input)
        inp = self.conv1x1_1(inp)
        # Branch 1 (dilation 1): three cascaded factorized stages.
        o1_1 = self.dconv3x1_1_1(inp)
        o1_1 = self.dconv1x3_1_1(o1_1)
        o1_2 = self.dconv3x1_1_2(o1_1)
        o1_2 = self.dconv1x3_1_2(o1_2)
        o1_3 = self.dconv3x1_1_3(o1_2)
        o1_3 = self.dconv1x3_1_3(o1_3)
        # Branch 2 (dilation int(d/4 + 1)).
        o2_1 = self.dconv3x1_2_1(inp)
        o2_1 = self.dconv1x3_2_1(o2_1)
        o2_2 = self.dconv3x1_2_2(o2_1)
        o2_2 = self.dconv1x3_2_2(o2_2)
        o2_3 = self.dconv3x1_2_3(o2_2)
        o2_3 = self.dconv1x3_2_3(o2_3)
        # Branch 3 (dilation int(d/2 + 1)).
        o3_1 = self.dconv3x1_3_1(inp)
        o3_1 = self.dconv1x3_3_1(o3_1)
        o3_2 = self.dconv3x1_3_2(o3_1)
        o3_2 = self.dconv1x3_3_2(o3_2)
        o3_3 = self.dconv3x1_3_3(o3_2)
        o3_3 = self.dconv1x3_3_3(o3_3)
        # Branch 4 (dilation d + 1).
        o4_1 = self.dconv3x1_4_1(inp)
        o4_1 = self.dconv1x3_4_1(o4_1)
        o4_2 = self.dconv3x1_4_2(o4_1)
        o4_2 = self.dconv1x3_4_2(o4_2)
        o4_3 = self.dconv3x1_4_3(o4_2)
        o4_3 = self.dconv1x3_4_3(o4_3)
        # Per-branch concat: nIn/16 + nIn/16 + nIn/8 = nIn/4 channels each.
        output_1 = torch.cat([o1_1, o1_2, o1_3], 1)
        output_2 = torch.cat([o2_1, o2_2, o2_3], 1)
        output_3 = torch.cat([o3_1, o3_2, o3_3], 1)
        output_4 = torch.cat([o4_1, o4_2, o4_3], 1)
        # Hierarchical feature fusion: each branch accumulates all smaller-
        # dilation branches before the final concatenation.
        ad1 = output_1
        ad2 = ad1 + output_2
        ad3 = ad2 + output_3
        ad4 = ad3 + output_4
        output = torch.cat([ad1, ad2, ad3, ad4], 1)  # back to nIn channels
        output = self.bn_relu_2(output)
        output = self.conv1x1(output)
        # Residual connection around the whole module.
        return output + input
class CFPNet(nn.Module):
    """CFPNet: stem + two CFP stages with input-injection shortcuts.

    Args:
        classes: number of output classes (channels of the final prediction).
        block_1: number of CFP modules in stage 1.
        block_2: number of CFP modules in stage 2.
    """

    def __init__(self, classes=11, block_1=2, block_2=6):
        super().__init__()
        # Stem: three 3x3 convs, the first strided (1/2 resolution, 32 ch).
        self.init_conv = nn.Sequential(
            Conv(3, 32, 3, 2, padding=1, bn_acti=True),
            Conv(32, 32, 3, 1, padding=1, bn_acti=True),
            Conv(32, 32, 3, 1, padding=1, bn_acti=True),
        )
        # Average-pooled copies of the raw image at 1/2, 1/4 and 1/8 scale.
        self.down_1 = InputInjection(1)
        self.down_2 = InputInjection(2)
        self.down_3 = InputInjection(3)
        self.bn_prelu_1 = BNPReLU(32 + 3)

        # Stage 1: down-sample to 1/4 resolution, 64 channels.
        dilation_block_1 = [2, 2]
        self.downsample_1 = DownSamplingBlock(32 + 3, 64)
        self.CFP_Block_1 = nn.Sequential()
        for idx in range(block_1):
            self.CFP_Block_1.add_module(
                f"CFP_Module_1_{idx}", CFPModule(64, d=dilation_block_1[idx]))
        self.bn_prelu_2 = BNPReLU(128 + 3)

        # Stage 2: down-sample to 1/8 resolution, 128 channels.
        dilation_block_2 = [4, 4, 8, 8, 16, 16]  # camvid / cityscapes setting
        self.downsample_2 = DownSamplingBlock(128 + 3, 128)
        self.CFP_Block_2 = nn.Sequential()
        for idx in range(block_2):
            self.CFP_Block_2.add_module(
                f"CFP_Module_2_{idx}", CFPModule(128, d=dilation_block_2[idx]))
        self.bn_prelu_3 = BNPReLU(256 + 3)

        # 259 = 128 (block 2) + 128 (its input) + 3 (injected image) channels.
        self.classifier = nn.Sequential(Conv(259, classes, 1, 1, padding=0))

    def forward(self, input):
        stem = self.init_conv(input)
        # Multi-scale copies of the raw input for injection.
        inj_half = self.down_1(input)
        inj_quarter = self.down_2(input)
        inj_eighth = self.down_3(input)

        # Fuse stem features with the 1/2-scale injected image.
        stage0 = self.bn_prelu_1(torch.cat([stem, inj_half], 1))

        # Stage 1: CFP block 1 plus a dense shortcut from its own input.
        s1_in = self.downsample_1(stage0)
        s1_out = self.CFP_Block_1(s1_in)
        stage1 = self.bn_prelu_2(torch.cat([s1_out, s1_in, inj_quarter], 1))

        # Stage 2: CFP block 2 with the same concatenation pattern.
        s2_in = self.downsample_2(stage1)
        s2_out = self.CFP_Block_2(s2_in)
        stage2 = self.bn_prelu_3(torch.cat([s2_out, s2_in, inj_eighth], 1))

        logits = self.classifier(stage2)
        # Bilinearly up-sample the 1/8-scale logits back to input resolution.
        return F.interpolate(logits, input.size()[2:], mode='bilinear',
                             align_corners=False)
# Smoke test: run a dummy batch through the network and verify that the
# output resolution matches the input with one channel per class.
if __name__ == "__main__":
    net = CFPNet()
    x = torch.randn((1, 3, 512, 512))
    with torch.no_grad():  # inference only — skip autograd bookkeeping
        output = net(x)
    assert output.shape == (1, 11, 512, 512), output.shape
    print("Output shape:", output.shape)