介绍
论文:https://arxiv.org/abs/1911.09070v7
代码:https://github.com/google/automl/tree/master/efficientdet
复现代码:https://github.com/jewelc92/mmdetection/blob/3.x/projects/EfficientDet/efficientdet/bifpn.py
原理:
- 高效的双向跨尺度连接
- 加权特征图融合
图2:特征网络设计 - (a)FPN引入自上而下的路径,以融合第3级到第7级(P3 - P7)的多尺度特征;(b)PANet在FPN之上添加了一条额外的自下而上的路径;(c)NAS-FPN使用神经架构搜索找到不规则的特征网络拓扑,然后重复应用同一个模块;(d)是本文提出的BiFPN,它在准确性与效率之间取得了更好的权衡。
论文提出了几种针对跨尺度连接的优化方法:
- 首先,删去那些只有一个输入的节点,因为如果一个节点只有一个输入没有特征融合的过程,那么它对旨在融合不同特征的网络的贡献就会比较小。
- 其次,如果原始输入和输出节点处于同一层级,增加一条额外的输入路径,从而在不增加太多计算成本的情况下融合更多的特征。
- 最后,与PANet只有一个自上而下和一个自下而上的路径不同,我们将每个双向路径视为一个特征网络层,并重复多次,从而实现更高级的特征融合。
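当融合不同分辨率的特征时,常见的做法是先将它们调整到相同的分辨率,然后直接相加。以往的方法平等地对待所有输入特征,但作者观察到,由于输入特征的分辨率不同,它们对输出特征的贡献通常也不相同。因此,本文为每个输入添加一个额外的可学习权重,让网络自行学习每个输入特征的重要性。下面的代码采用归一化的加权融合:$O = \sum_i \frac{w_i}{\epsilon + \sum_j w_j} \cdot I_i$,其中每个权重先经过ReLU以保证 $w_i \ge 0$,$\epsilon$ 是一个很小的常数,用于避免分母为零。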
代码
# Copyright (c) OpenMMLab. All rights reserved.
# Modified from https://github.com/zylo117/Yet-Another-EfficientDet-Pytorch
from typing import List
import torch
import torch.nn as nn
from mmcv.cnn.bricks import Swish
from mmengine.model import BaseModule
from mmdet.registry import MODELS
from mmdet.utils import MultiConfig, OptConfigType
from .utils import (DepthWiseConvBlock, DownChannelBlock, MaxPool2dSamePadding,
MemoryEfficientSwish)
class BiFPNStage(nn.Module):
    """A single BiFPN layer: one top-down pass, then one bottom-up pass.

    Every fusion node blends its inputs with learnable weights that are
    clamped by ReLU and normalized as ``w_i / (sum(w) + epsilon)``; the
    blended map is then passed through a ``DepthWiseConvBlock``.

    Args:
        in_channels (List[int]): Channel counts of the incoming features.
            Only read when ``first_time`` is True, in which case the last
            three entries are used for P3, P4 and P5 respectively.
        out_channels (int): Channel count requested from every projection
            and fusion block of this stage.
        first_time (bool): Whether this is the first stage of the stack.
            If True the stage takes three features (P3, P4, P5), projects
            them with ``DownChannelBlock`` and derives P6 and P7 itself;
            otherwise it takes five already-projected features (P3 - P7).
        apply_bn_for_resampling (bool): Forwarded as ``apply_norm`` to
            every ``DownChannelBlock`` and ``DepthWiseConvBlock``.
        conv_bn_act_pattern (bool): Forwarded to every conv block. When
            False the stage applies swish to each fused map before the
            fusion conv; when True the stage applies no activation itself.
        use_meswish (bool): Use ``MemoryEfficientSwish`` if True, otherwise
            mmcv's ``Swish``.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for the
            normalization layer, forwarded to every conv block.
        epsilon (float): Small constant added to the weight sum when
            normalizing the fusion weights.
    """

    def __init__(self,
                 in_channels: List[int],
                 out_channels: int,
                 first_time: bool = False,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 use_meswish: bool = True,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 epsilon: float = 1e-4) -> None:
        super().__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.first_time = first_time
        self.apply_bn_for_resampling = apply_bn_for_resampling
        self.conv_bn_act_pattern = conv_bn_act_pattern
        self.use_meswish = use_meswish
        self.norm_cfg = norm_cfg
        self.epsilon = epsilon

        # Keyword arguments shared by every conv block built below.
        block_kwargs = dict(
            apply_norm=self.apply_bn_for_resampling,
            conv_bn_act_pattern=self.conv_bn_act_pattern,
            norm_cfg=norm_cfg)

        def project(channels):
            # Maps a raw input feature to ``out_channels``.
            return DownChannelBlock(channels, self.out_channels,
                                    **block_kwargs)

        def fuse_conv():
            # Conv block applied after each weighted fusion.
            return DepthWiseConvBlock(out_channels, out_channels,
                                      **block_kwargs)

        def fusion_weight(num_inputs):
            # One learnable scalar per fused input, initialised to 1.
            return nn.Parameter(
                torch.ones(num_inputs, dtype=torch.float32),
                requires_grad=True)

        # NOTE: attribute names and construction order are kept stable on
        # purpose; they determine the state_dict keys and the order in
        # which randomly initialised layers are created.
        if self.first_time:
            self.p5_down_channel = project(self.in_channels[-1])
            self.p4_down_channel = project(self.in_channels[-2])
            self.p3_down_channel = project(self.in_channels[-3])
            # P6 is built from raw P5 (projection + pooling), P7 from P6.
            self.p5_to_p6 = nn.Sequential(
                project(self.in_channels[-1]), MaxPool2dSamePadding(3, 2))
            self.p6_to_p7 = MaxPool2dSamePadding(3, 2)
            # Separate projections of raw P4 / P5 for the bottom-up nodes.
            self.p4_level_connection = project(self.in_channels[-2])
            self.p5_level_connection = project(self.in_channels[-1])

        # top to bottom: feature map up_sample modules
        self.p6_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p5_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p4_upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.p3_upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # bottom to up: feature map down_sample modules
        self.p4_down_sample = MaxPool2dSamePadding(3, 2)
        self.p5_down_sample = MaxPool2dSamePadding(3, 2)
        self.p6_down_sample = MaxPool2dSamePadding(3, 2)
        self.p7_down_sample = MaxPool2dSamePadding(3, 2)

        # Fuse conv layers: one per fusion node.
        self.conv6_up = fuse_conv()
        self.conv5_up = fuse_conv()
        self.conv4_up = fuse_conv()
        self.conv3_up = fuse_conv()
        self.conv4_down = fuse_conv()
        self.conv5_down = fuse_conv()
        self.conv6_down = fuse_conv()
        self.conv7_down = fuse_conv()

        # Fusion weights: ``*_w1`` belong to the top-down nodes (two
        # inputs each), ``*_w2`` to the bottom-up nodes (three inputs,
        # except P7 which has two).
        self.p6_w1 = fusion_weight(2)
        self.p6_w1_relu = nn.ReLU()
        self.p5_w1 = fusion_weight(2)
        self.p5_w1_relu = nn.ReLU()
        self.p4_w1 = fusion_weight(2)
        self.p4_w1_relu = nn.ReLU()
        self.p3_w1 = fusion_weight(2)
        self.p3_w1_relu = nn.ReLU()

        self.p4_w2 = fusion_weight(3)
        self.p4_w2_relu = nn.ReLU()
        self.p5_w2 = fusion_weight(3)
        self.p5_w2_relu = nn.ReLU()
        self.p6_w2 = fusion_weight(3)
        self.p6_w2_relu = nn.ReLU()
        self.p7_w2 = fusion_weight(2)
        self.p7_w2_relu = nn.ReLU()

        self.swish = MemoryEfficientSwish() if use_meswish else Swish()

    def combine(self, x):
        """Apply swish to ``x`` unless ``conv_bn_act_pattern`` is enabled.

        NOTE(review): presumably the conv block applies its own activation
        when ``conv_bn_act_pattern`` is True - confirm in ``.utils``.
        """
        if not self.conv_bn_act_pattern:
            x = self.swish(x)
        return x

    def _fuse(self, weight, relu, *features):
        """Blend ``features`` with normalized, non-negative weights.

        Args:
            weight (nn.Parameter): One raw scalar weight per feature.
            relu (nn.Module): Activation keeping the weights non-negative.
            *features (Tensor): Maps to blend; they are summed elementwise,
                so callers resample them to a common size beforehand.

        Returns:
            Tensor: ``combine(sum_i w_i * features[i])`` where
            ``w = relu(weight) / (sum(relu(weight)) + epsilon)``.
        """
        w = relu(weight)
        w = w / (torch.sum(w, dim=0) + self.epsilon)
        # Accumulate left to right, matching ``w0*a + w1*b (+ w2*c)``.
        fused = w[0] * features[0]
        for idx, feat in enumerate(features[1:], start=1):
            fused = fused + w[idx] * feat
        return self.combine(fused)

    def forward(self, x):
        """Run the top-down and bottom-up fusion passes.

        Args:
            x: If ``first_time`` is True, exactly three tensors
                ``(p3, p4, p5)``; otherwise exactly five tensors
                ``(p3_in, p4_in, p5_in, p6_in, p7_in)``.

        Returns:
            tuple: ``(p3_out, p4_out, p5_out, p6_out, p7_out)``.
        """
        # Shape comments below are examples only, e.g. for inputs of
        # (1,40,64,64), (1,112,32,32), (1,320,16,16) and out_channels=64.
        if self.first_time:
            p3, p4, p5 = x
            # build feature map P6
            p6_in = self.p5_to_p6(p5)  # e.g. (1,64,8,8)
            # build feature map P7
            p7_in = self.p6_to_p7(p6_in)  # e.g. (1,64,4,4)

            p3_in = self.p3_down_channel(p3)  # e.g. (1,64,64,64)
            p4_in = self.p4_down_channel(p4)  # e.g. (1,64,32,32)
            p5_in = self.p5_down_channel(p5)  # e.g. (1,64,16,16)
        else:
            p3_in, p4_in, p5_in, p6_in, p7_in = x

        # ---- top-down pass -------------------------------------------
        # P6_0 and P7_0 -> P6_1
        p6_up = self.conv6_up(
            self._fuse(self.p6_w1, self.p6_w1_relu, p6_in,
                       self.p6_upsample(p7_in)))  # e.g. (1,64,8,8)
        # P5_0 and P6_1 -> P5_1
        p5_up = self.conv5_up(
            self._fuse(self.p5_w1, self.p5_w1_relu, p5_in,
                       self.p5_upsample(p6_up)))  # e.g. (1,64,16,16)
        # P4_0 and P5_1 -> P4_1
        p4_up = self.conv4_up(
            self._fuse(self.p4_w1, self.p4_w1_relu, p4_in,
                       self.p4_upsample(p5_up)))  # e.g. (1,64,32,32)
        # P3_0 and P4_1 -> P3_2
        p3_out = self.conv3_up(
            self._fuse(self.p3_w1, self.p3_w1_relu, p3_in,
                       self.p3_upsample(p4_up)))  # e.g. (1,64,64,64)

        if self.first_time:
            # p4_level_connection / p5_level_connection are built with the
            # same arguments as p4_down_channel / p5_down_channel but are
            # distinct module instances, so the bottom-up nodes receive
            # their own projection of the raw backbone features instead of
            # reusing the top-down p4_in / p5_in.
            # NOTE(review): presumably this mirrors the reference
            # implementation, where each edge resamples independently -
            # confirm before merging the two projections.
            p4_in = self.p4_level_connection(p4)
            p5_in = self.p5_level_connection(p5)

        # ---- bottom-up pass ------------------------------------------
        # P4_0, P4_1 and P3_2 -> P4_2
        p4_out = self.conv4_down(
            self._fuse(self.p4_w2, self.p4_w2_relu, p4_in, p4_up,
                       self.p4_down_sample(p3_out)))  # e.g. (1,64,32,32)
        # P5_0, P5_1 and P4_2 -> P5_2
        p5_out = self.conv5_down(
            self._fuse(self.p5_w2, self.p5_w2_relu, p5_in, p5_up,
                       self.p5_down_sample(p4_out)))  # e.g. (1,64,16,16)
        # P6_0, P6_1 and P5_2 -> P6_2
        p6_out = self.conv6_down(
            self._fuse(self.p6_w2, self.p6_w2_relu, p6_in, p6_up,
                       self.p6_down_sample(p5_out)))  # e.g. (1,64,8,8)
        # P7_0 and P6_2 -> P7_2
        p7_out = self.conv7_down(
            self._fuse(self.p7_w2, self.p7_w2_relu, p7_in,
                       self.p7_down_sample(p6_out)))  # e.g. (1,64,4,4)

        return p3_out, p4_out, p5_out, p6_out, p7_out
@MODELS.register_module()
class BiFPN(BaseModule):
    """A stack of ``BiFPNStage`` layers applied one after another.

    Only the first stage is built with ``first_time=True``; it consumes
    the raw input features, while every later stage consumes the five
    maps produced by the stage before it.

    Args:
        num_stages (int): Number of ``BiFPNStage`` repeats.
        in_channels (List[int]): Channel counts of the input features,
            forwarded to every stage. ``BiFPNStage`` asserts that this is
            a ``list`` and reads its last three entries as P3, P4, P5.
        out_channels (int): Channel count forwarded to every stage.
        start_level (int): Index of the first input feature to use; the
            features before it are dropped in ``forward``.
        epsilon (float): Small constant used when normalizing the fusion
            weights, forwarded to every stage.
        apply_bn_for_resampling (bool): Forwarded to every stage.
        conv_bn_act_pattern (bool): Forwarded to every stage.
        use_meswish (bool): Forwarded to every stage; selects
            ``MemoryEfficientSwish`` (True) or ``Swish`` (False) there.
        norm_cfg (:obj:`ConfigDict` or dict, optional): Config dict for the
            normalization layer, forwarded to every stage.
        init_cfg (MultiConfig): Initialization config passed to
            ``BaseModule``.
    """

    def __init__(self,
                 num_stages: int,
                 in_channels: List[int],
                 out_channels: int,
                 start_level: int = 0,
                 epsilon: float = 1e-4,
                 apply_bn_for_resampling: bool = True,
                 conv_bn_act_pattern: bool = False,
                 use_meswish: bool = True,
                 norm_cfg: OptConfigType = dict(
                     type='BN', momentum=1e-2, eps=1e-3),
                 init_cfg: MultiConfig = None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.start_level = start_level
        stages = [
            BiFPNStage(
                in_channels=in_channels,
                out_channels=out_channels,
                # Only the first stage sees the raw input features.
                first_time=(stage_idx == 0),
                apply_bn_for_resampling=apply_bn_for_resampling,
                conv_bn_act_pattern=conv_bn_act_pattern,
                use_meswish=use_meswish,
                norm_cfg=norm_cfg,
                epsilon=epsilon) for stage_idx in range(num_stages)
        ]
        self.bifpn = nn.Sequential(*stages)

    def forward(self, x):
        """Fuse the input features through all stages.

        Args:
            x: Sequence of input feature maps. After dropping the first
                ``start_level`` entries exactly three must remain, because
                the first stage unpacks them as ``(p3, p4, p5)``; e.g.
                shapes (1,40,64,64), (1,112,32,32), (1,320,16,16).

        Returns:
            The output of the last stage, i.e. the five fused maps
            ``(p3_out, p4_out, p5_out, p6_out, p7_out)``.
        """
        return self.bifpn(x[self.start_level:])