Adding 12 attention mechanisms to mmdetection 3

Create attention_layers.py in the mmdetection/mmdet/models/layers/ directory:

import torch.nn as nn
from mmdet.registry import MODELS
# Custom attention modules; the implementations live in the attention/ sub-package created below.
from .attention.CBAM import CBAMBlock as _CBAMBlock
from .attention.BAM import BAMBlock as _BAMBlock
from .attention.SEAttention import SEAttention as _SEAttention
from .attention.ECAAttention import ECAAttention as _ECAAttention
from .attention.ShuffleAttention import ShuffleAttention as _ShuffleAttention
from .attention.SGE import SpatialGroupEnhance as _SpatialGroupEnhance
from .attention.A2Atttention import DoubleAttention as _DoubleAttention
from .attention.PolarizedSelfAttention import SequentialPolarizedSelfAttention as _SequentialPolarizedSelfAttention
from .attention.CoTAttention import CoTAttention as _CoTAttention
from .attention.TripletAttention import TripletAttention as _TripletAttention
from .attention.CoordAttention import CoordAtt as _CoordAtt
from .attention.ParNetAttention import ParNetAttention as _ParNetAttention


@MODELS.register_module()
class CBAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CBAMBlock, self).__init__()
        print("======激活注意力机制模块【CBAMBlock】======")
        self.module = _CBAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)
    
    
@MODELS.register_module()
class BAMBlock(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(BAMBlock, self).__init__()
        print("======激活注意力机制模块【BAMBlock】======")
        self.module = _BAMBlock(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SEAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SEAttention, self).__init__()
        print("======激活注意力机制模块【SEAttention】======")
        self.module = _SEAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)   
 

@MODELS.register_module()
class ECAAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ECAAttention, self).__init__()
        print("======激活注意力机制模块【ECAAttention】======")
        self.module = _ECAAttention(**kwargs)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class ShuffleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ShuffleAttention, self).__init__()
        print("======激活注意力机制模块【ShuffleAttention】======")
        self.module = _ShuffleAttention(channel = in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)


@MODELS.register_module()
class SpatialGroupEnhance(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SpatialGroupEnhance, self).__init__()
        print("======激活注意力机制模块【SpatialGroupEnhance】======")
        self.module = _SpatialGroupEnhance(**kwargs)

    def forward(self, x):
        return self.module(x)   
    

@MODELS.register_module()
class DoubleAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(DoubleAttention, self).__init__()
        print("======激活注意力机制模块【DoubleAttention】======")
        self.module = _DoubleAttention(in_channels, 128, 128,True)

    def forward(self, x):
        return self.module(x)  


@MODELS.register_module()
class SequentialPolarizedSelfAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(SequentialPolarizedSelfAttention, self).__init__()
        print("======激活注意力机制模块【Polarized Self-Attention】======")
        self.module = _SequentialPolarizedSelfAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)   
    
    
@MODELS.register_module()
class CoTAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoTAttention, self).__init__()
        print("======激活注意力机制模块【CoTAttention】======")
        self.module = _CoTAttention(dim=in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)  

    
@MODELS.register_module()
class TripletAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(TripletAttention, self).__init__()
        print("======激活注意力机制模块【TripletAttention】======")
        self.module = _TripletAttention()

    def forward(self, x):
        return self.module(x)      


@MODELS.register_module()
class CoordAtt(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(CoordAtt, self).__init__()
        print("======激活注意力机制模块【CoordAtt】======")
        self.module = _CoordAtt(in_channels, in_channels, **kwargs)

    def forward(self, x):
        return self.module(x)    


@MODELS.register_module()
class ParNetAttention(nn.Module):
    
    def __init__(self, in_channels, **kwargs):
        super(ParNetAttention, self).__init__()
        print("======激活注意力机制模块【ParNetAttention】======")
        self.module = _ParNetAttention(channel=in_channels)

    def forward(self, x):
        return self.module(x)  

In the same directory as attention_layers.py, create an attention folder and place the 12 attention-mechanism implementation files in it, as sketched below.
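
The expected layout is roughly as follows; the file names follow the import statements above, and an empty __init__.py inside attention/ may be needed so the relative imports resolve (this layout is an assumption based on the imports, not an exact listing of the download).

mmdet/models/layers/
├── attention_layers.py
└── attention/
    ├── __init__.py
    ├── CBAM.py
    ├── BAM.py
    ├── SEAttention.py
    ├── ECAAttention.py
    ├── ShuffleAttention.py
    ├── SGE.py
    ├── A2Atttention.py
    ├── PolarizedSelfAttention.py
    ├── CoTAttention.py
    ├── TripletAttention.py
    ├── CoordAttention.py
    └── ParNetAttention.py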

Download link (the 12 attention-mechanism source files for mmdetection 3): https://download.csdn.net/download/lanyan90/89513979
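
Before touching any config, a quick smoke test confirms that the modules register and run. This is a minimal sketch: the in_channels value and the random tensor shape are arbitrary assumptions, while reduction=16 and kernel_size=7 mirror the CBAMBlock options used in the configs below.

import torch
from mmdet.registry import MODELS
import mmdet.models.layers.attention_layers  # noqa: F401  -- importing triggers @MODELS.register_module()

# Build the wrapper through the registry, the same way mmdetection builds plugins from a config.
block = MODELS.build(dict(type='CBAMBlock', in_channels=256, reduction=16, kernel_size=7))
x = torch.randn(2, 256, 64, 64)
print(block(x).shape)  # attention blocks keep the input shape: torch.Size([2, 256, 64, 64])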

Usage:

Taking faster-rcnn_r50 as an example, create faster-rcnn_r50_fpn_1x_coco_attention.py:

_base_ = 'configs/detection/faster_rcnn/faster-rcnn_r50_fpn_1x_coco.py'

custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins=[
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

To use a particular attention mechanism, uncomment the corresponding cfg line in plugins; exactly one cfg entry should be active, as in the sketch below.
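
For concreteness, here is a minimal sketch of the backbone section with CBAMBlock enabled; the commented stages entry is the standard mmdet ResNet plugin field for restricting the plugin to selected stages, and its values are only an illustration.

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(type='CBAMBlock', reduction=16, kernel_size=7),
                # stages=(False, True, True, True),  # optional: skip the first stage
                position='after_conv3')
        ]))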

Taking mask-rcnn_r50 as an example, create mask-rcnn_r50_fpn_1x_coco_attention.py:

_base_ = 'configs/segmentation/mask_rcnn/mask-rcnn_r50_fpn_1x_coco.py'
custom_imports = dict(imports=['mmdet.models.layers.attention_layers'], allow_failed_imports=False)

model = dict(
    backbone=dict(
        plugins=[
            dict(
                position='after_conv3',
                #cfg = dict(type='CBAMBlock', reduction=16, kernel_size=7)
                #cfg = dict(type='BAMBlock', reduction=16, dia_val=1)
                #cfg = dict(type='SEAttention', reduction=8)
                #cfg = dict(type='ECAAttention', kernel_size=3)
                #cfg = dict(type='ShuffleAttention', G=8)
                #cfg = dict(type='SpatialGroupEnhance', groups=8)
                #cfg = dict(type='DoubleAttention')
                #cfg = dict(type='SequentialPolarizedSelfAttention')
                #cfg = dict(type='CoTAttention', kernel_size=3)
                #cfg = dict(type='TripletAttention')
                #cfg = dict(type='CoordAtt', reduction=32)
                #cfg = dict(type='ParNetAttention')
            )
        ]
    )
)

The usage is exactly the same; see the training commands below.
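
Training is launched with the standard tools/train.py entry point; the banner printed by the wrapper class ("====== Attention module [...] enabled ======") in the startup log confirms the plugin was actually built. The commands below assume the new config files were saved in the mmdetection root directory so that the relative _base_ paths resolve; adjust the paths to wherever you saved them.

python tools/train.py faster-rcnn_r50_fpn_1x_coco_attention.py
python tools/train.py mask-rcnn_r50_fpn_1x_coco_attention.py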
