一、FPN网络结构
二、FPN网络结构解释
FPN的总体架构如上图所示,主要包含自下而上网络、自上而下网络、横向连接与卷积融合4个部分。
自下而上:
最左侧为普通的卷积网络,默认使用ResNet结构,用作提取语义信息。C1代表了ResNet的前几个卷积与池化层,而C2至C5分别为不同的ResNet卷积组,这些卷积组包含了多个Bottleneck结构,组内的特征图大小相同,组间大小递减。
自上而下:
首先对C5进行1×1卷积降低通道数得到P5,然后依次进行上采样得到P4、P3和P2,目的是得到与C4、C3与C2长宽相同的特征,以方便下一步进行逐元素相加。这里采用2倍最邻近上采样,即直接对临近元素进行复制,而非线性插值。
横向连接(Lateral Connection):
目的是为了将上采样后的高语义特征与浅层的定位细节特征进行融合。高语义特征经过上采样后,其长宽与对应的浅层特征相同,而通道数固定为256,因此需要对底层特征C2至C4进行11卷积使得其通道数变为256,然后两者进行逐元素相加得到P4、P3与P2。由于C1的特征图尺寸较大且语义信息不足,因此没有把C1放到横向连接中。
卷积融合:
在得到相加后的特征后,利用3×3卷积对生成的P2至P4再进行融合,目的是消除上采样过程带来的重叠效应,以生成最终的特征图。
原文链接:https://blog.csdn.net/weixin_45564943/article/details/121643728
FPN将深层的语义信息传到底层,来补充浅层的语义信息,从而获得了高分辨率、强语义的特征,在小物体检测、实例分割等领域有着非常不俗的表现。
3、代码复现
import torch
import torch.nn as nn
def double_conv(in_channels,out_channels):
return nn.Sequential(
nn.Conv2d(in_channels,out_channels,kernel_size=3,stride=1,padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(out_channels,out_channels,kernel_size=3,stride=1,padding=1),
nn.ReLU(inplace=True)
)
class FPN(nn.Module):
def __init__(self,in_channels,out_channels):
super().__init__()
self.conv_down1=double_conv(in_channels,64)
self.conv_down2=double_conv(64,128)
self.conv_down3=double_conv(128,256)
self.conv_down4=double_conv(256,512)
self.conv_down5=double_conv(512,1024)
self.maxpool=nn.MaxPool2d(kernel_size=2,stride=2)
self.toplayer=nn.Conv2d(1024,256,kernel_size=1,stride=1,padding=0)
self.toplayer1=nn.Conv2d(512,256,kernel_size=1,stride=1,padding=0)
self.toplayer2=nn.Conv2d(256,256,kernel_size=1,stride=1,padding=0)
self.toplayer3=nn.Conv2d(128,256,kernel_size=1,stride=1,padding=0)
self.smooth=nn.Conv2d(256,256,kernel_size=3,stride=1,padding=1)
def upsample(self,x,y):
_,_,h,w=y.shape
return nn.functional.interpolate(x,size=(h,w),mode='bilinear',align_corners=True)+y
def forward(self,x):
c1=self.maxpool(self.conv_down1(x))
c2=self.maxpool(self.conv_down2(c1))
c3=self.maxpool(self.conv_down3(c2))
c4=self.maxpool(self.conv_down4(c3))
c5=self.maxpool(self.conv_down5(c4))
p5= self.toplayer(c5)
p4 = self.upsample(p5,self.toplayer1(c4))
p3 = self.upsample(p4,self.toplayer2(c3))
p2 = self.upsample(p3,self.toplayer3(c2))
p4 = self.smooth(p4)
p3 = self.smooth(p3)
p2 = self.smooth(p2)
return p4,p3,p2
fpn=FPN(3,256)
p4,p3,p2=fpn(torch.randn(1,3,256,256))
print(p4.shape)
print(p3.shape)
print(p2.shape)