MMDetection中对Resnet增加注意力机制Attention的简单方法

代码

import torch
from mmdet.models.backbones import ResNet
from fightingcv_attention.attention.CoordAttention import CoordAtt
from fightingcv_attention.attention.SEAttention import SEAttention
from mmdet.models.builder import BACKBONES
 
 
# 定义带attention的resnet18基类
class ResNetWithAttention(ResNet):
    def __init__(self , **kwargs):
        super(ResNetWithAttention, self).__init__(**kwargs)
        # 目前将注意力模块加在最后的三个输出特征层
        # resnet输出四个特征层
        if self.depth in (18, 34):
            self.dims = (64, 128, 256, 512)
        elif self.depth in (50, 101, 152):
            self.dims = (256, 512, 1024, 2048)
        else:
            raise Exception()
        self.attention1 = self.get_attention_module(self.dims[1])     
        self.attention2 = self.get_attention_module(self.dims[2])     
        self.attention3 = self.get_attention_module(self.dims[3])     
    
    # 子类只需要实现该attention即可
    def get_attention_module(self, dim):
        raise NotImplementedError()
    
    def forward(self, x):
        outs = super().forward(x)
        outs = list(outs)
        outs[1] = self.attention1(outs[1])
        outs[2] = self.attention2(outs[2])
        outs[3] = self.attention3(outs[3])    
        outs = tuple(outs)
        return outs
    
@BACKBONES.register_module()
class ResNetWithCoordAttention(ResNetWithAttention):
    def __init__(self , **kwargs):
        super(ResNetWithCoordAttention, self).__init__(**kwargs)
 
    # 子类只需要实现该attention即可
    def get_attention_module(self, dim):
        return CoordAtt(inp=dim, oup=dim, reduction=32)
    
@BACKBONES.register_module()
class ResNetWithSEAttention(ResNetWithAttention):
    def __init__(self , **kwargs):
        super(ResNetWithSEAttention, self).__init__(**kwargs)
 
    # 子类只需要实现该attention即可
    def get_attention_module(self, dim):
        return SEAttention(channel=dim, reduction=16)
 
 
if __name__ == "__main__":
    # model = ResNet(depth=18)
    # model = ResNet(depth=34)
    # model = ResNet(depth=50)
    # model = ResNet(depth=101)    
    # model = ResNet(depth=152)
    # model = ResNetWithCoordAttention(depth=18)
    model = ResNetWithSEAttention(depth=18)
    x = torch.rand(1, 3, 224, 224)
    outs = model(x)
    # print(outs.shape)
    for i, out in enumerate(outs):
        print(i, out.shape)

以resnet为例子,我在多个尺度的特征层输出增加注意力机制,以此编写一个基类,子类只需要实现这个attention即可。

参考开源仓库实现attention:

GitHub - xmu-xiaoma666/External-Attention-pytorch: 🍀 Pytorch implementation of various Attention Mechanisms, MLP, Re-parameter, Convolution, which is helpful to further understand papers.⭐⭐⭐

当然也可以直接pip调用:

pip install fightingcv-attention

测试完模型输出后可以利用注册到mmdetection:

简单的方法是,添加backbone注册修饰器,并在train.py和test.py中,import 该文件。

在配置上将model的type从Resnet更改为ResNetWithSEAttention或者ResNetWithSEAttention即可。

注意

(1)这是基于mmdet 2.x版本的改动。

(2)需要注意增加网络的输入输出形状,可能还需要考虑不同宽高输入时网络输出是否与设计一致。

(3)注意力放各尺寸末端是因为想尽可能使用预训练权重(提高预训练权重利用率),如果是重新训练的话,放哪里效果最佳我无法确定。

(4)增加注意力机制不一定有效涨点,只能是尝试一下。注意力机制解决的问题如果与数据对应,可以尝试一下。譬如需要增加空间长距离特征提取能力可以用空间注意力,或者简单的使用通道注意力都行。如果你的数据量少或者任务较简单,尽量不要用参数较多的模型。

  • 13
    点赞
  • 84
    收藏
    觉得还不错? 一键收藏
  • 24
    评论
评论 24
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值