Study Notes 0: Adding SE Attention to the FPN in mmdet

First, create a new file for the SE module (it is imported later as se_attention, so name the file se_attention.py).

Port (copy) the SE attention code from an open-source repository. When porting, remember to replace nn.Module with the BaseModule that mmdet uses (provided by mmengine.model).


from mmengine.model import BaseModule
import torch.nn as nn


class SEAttention(BaseModule):
    """Squeeze-and-Excitation channel attention."""

    def __init__(self,
                 channels,
                 reduction=16,
                 init_cfg=None,
                 **kwargs):
        super(SEAttention, self).__init__(init_cfg=init_cfg)
        # Squeeze: global average pooling down to a 1x1 spatial map.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Excitation: bottleneck MLP that produces per-channel weights in (0, 1).
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        b, c, _, _ = x.size()
        y = self.avg_pool(x).view(b, c)      # (B, C)
        y = self.fc(y).view(b, c, 1, 1)      # (B, C, 1, 1) channel weights
        return x * y.expand_as(x)            # rescale each channel of the input

SE attention does not change the data layout: the output has exactly the same shape as the input, which makes it easy to insert into a network. A quick sanity check is sketched below.
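A minimal, purely illustrative check of the shape-preserving property (not part of the ported code):

import torch

se = SEAttention(channels=256)
x = torch.rand(2, 256, 32, 32)    # (B, C, H, W)
y = se(x)
assert y.shape == x.shape         # SE only rescales channels; the shape is unchanged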

In the same directory as the SE file, create a new FPN_SE file, import SEAttention, and subclass FPN.

from mmdet.models.necks import FPN
from mmdet.registry import MODELS

from .se_attention import SEAttention


@MODELS.register_module()
class FPNWithSEAttention(FPN):
    """FPN whose output feature maps are refined by SE channel attention."""

    def __init__(self,
                 in_channels,
                 out_channels,
                 num_outs,
                 start_level=0,
                 end_level=-1,
                 add_extra_convs=False,
                 relu_before_extra_convs=False,
                 no_norm_on_lateral=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None,
                 reduction=16,
                 upsample_cfg=dict(mode='nearest')):
        super(FPNWithSEAttention, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            num_outs=num_outs,
            start_level=start_level,
            end_level=end_level,
            add_extra_convs=add_extra_convs,
            relu_before_extra_convs=relu_before_extra_convs,
            no_norm_on_lateral=no_norm_on_lateral,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            upsample_cfg=upsample_cfg)
        # Every FPN output level has out_channels channels, so one shared
        # SE module can be applied to all levels.
        self.se_attention = SEAttention(out_channels, reduction)

    def forward(self, inputs):
        # Run the standard FPN forward, then refine each output level with SE.
        outs = super(FPNWithSEAttention, self).forward(inputs)
        outs = [self.se_attention(out) for out in outs]
        return tuple(outs)

Create an __init__.py file and import the FPN_SE module in it, so that the registration decorator runs when the package is initialized (a sketch follows).
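A minimal sketch of that __init__.py, assuming the two files are named se_attention.py and fpn_se.py (the file names are an assumption of this note; adjust to your layout):

# __init__.py
from .se_attention import SEAttention
from .fpn_se import FPNWithSEAttention

__all__ = ['SEAttention', 'FPNWithSEAttention']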

After that, the new neck can be used in a config file.

neck=dict(
    type='FPNWithSEAttention',
    in_channels=[192, 384, 768, 1536],
    out_channels=256,
    num_outs=5,
    reduction=16),
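If the custom neck lives outside the installed mmdet package (for example in a project directory rather than mmdet/models/necks), the config also needs to import it so that @MODELS.register_module() actually runs. The import path below is a placeholder assumption; point it at wherever the __init__.py above sits:

custom_imports = dict(
    imports=['projects.se_fpn'],      # assumed package path; adjust to your layout
    allow_failed_imports=False)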
