First, create a new file for the SE module.
Port (copy) the SE attention code from an open-source repository. When porting, note that nn.Module should be replaced with the BaseModule that MMDetection builds on (imported from mmengine).
from mmengine.model import BaseModule
import torch.nn as nn


class SEAttention(BaseModule):

    def __init__(self,
                 channels,
                 reduction=16,
                 init_cfg=None,
                 **kwargs):
        super(SEAttention, self).__init__(init_cfg=init_cfg)
        # Squeeze: global average pooling reduces each channel to a scalar
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: a bottleneck MLP maps the pooled vector to a
        # per-channel weight in (0, 1)
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid())

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)
        y = self.fc(y).view(b, c, 1, 1)
        # Reweight the input feature map channel by channel
        return x * y.expand_as(x)
SE attention does not change the data layout: the output shape matches the input shape, which makes it easy to insert anywhere in a network.
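A quick sanity check confirms this (a minimal standalone sketch; the channel count and spatial size here are arbitrary):

import torch

# SEAttention leaves the tensor shape untouched, so it can be dropped
# into any place that produces a 4-D feature map
se = SEAttention(channels=256)
x = torch.rand(2, 256, 32, 32)
assert se(x).shape == x.shape  # torch.Size([2, 256, 32, 32])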
In the same directory as the SE file, create a new FPN_SE file that imports the SE module and subclasses FPN.
from mmdet.models.necks import FPN
from mmdet.registry import MODELS

from .se_attention import SEAttention


@MODELS.register_module()
class FPNWithSEAttention(FPN):

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 reduction=16,
                 upsample_cfg=dict(mode='nearest')):
        super(FPNWithSEAttention, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            num_outs=num_outs,
            start_level=start_level,
            end_level=end_level,
            add_extra_convs=add_extra_convs,
            relu_before_extra_convs=relu_before_extra_convs,
            no_norm_on_lateral=no_norm_on_lateral,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            upsample_cfg=upsample_cfg)
        # Every FPN output level has out_channels channels, so a single
        # shared SE module can be applied to all of them
        self.se_attention = SEAttention(out_channels, reduction)

    def forward(self, inputs):
        # Run the standard FPN forward, then reweight each output level
        outs = super(FPNWithSEAttention, self).forward(inputs)
        outs = [self.se_attention(out) for out in outs]
        return tuple(outs)
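To verify the neck end to end, a small smoke test can feed backbone-like multi-scale features through it (a sketch; the channel counts mirror the config below, while the spatial sizes are arbitrary):

import torch

# Four input levels with shrinking spatial sizes, as a backbone would emit
feats = [torch.rand(1, c, s, s)
         for c, s in zip([192, 384, 768, 1536], [64, 32, 16, 8])]
neck = FPNWithSEAttention(
    in_channels=[192, 384, 768, 1536],
    out_channels=256,
    num_outs=5,
    reduction=16)
outs = neck(feats)
# Expect 5 output levels, each with 256 channels
print([tuple(o.shape) for o in outs])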
Create an __init__.py file and import the FPN_SE module in it, so the registration code runs when the package is initialized.
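A minimal __init__.py might look like this (the file names se_attention.py and fpn_se.py are assumptions based on the relative import above):

# __init__.py, next to se_attention.py and fpn_se.py (assumed file names)
from .se_attention import SEAttention
from .fpn_se import FPNWithSEAttention

__all__ = ['SEAttention', 'FPNWithSEAttention']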
After that, the new neck can be used in the config file.
neck=dict(
    type='FPNWithSEAttention',
    in_channels=[192, 384, 768, 1536],
    out_channels=256,
    num_outs=5,
    reduction=16),
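Note that @MODELS.register_module() only takes effect when the file is actually imported. If the package lives outside mmdet's own source tree, add a custom_imports entry to the config so the registry can find the new class (the module path below is a placeholder for wherever the package actually lives):

custom_imports = dict(
    imports=['projects.fpn_se'],  # hypothetical package path
    allow_failed_imports=False)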