【YOLO改进】主干插入Spatial Group-wise Enhance(SGE)模块(基于MMYOLO)

Spatial Group-wise Enhance(SGE)模块

论文链接:https://arxiv.org/abs/1905.09646

将Spatial Group-wise Enhance(SGE)模块添加到MMYOLO中

  1. 将开源代码SGE.py文件复制到mmyolo/models/plugins目录下

  2. 导入MMYOLO用于注册模块的包: from mmyolo.registry import MODELS

  3. 确保 class SpatialGroupEnhance中的输入维度为in_channels(因为MMYOLO会提前传入输入维度参数,所以要保持参数名的一致)

  4. 利用@MODELS.register_module()将“class SpatialGroupEnhance(nn.Module)”注册:

  5. 修改mmyolo/models/plugins/__init__.py文件

  6. 在终端运行:

    python setup.py install
  7. 修改对应的配置文件,并且将plugins的参数“type”设置为“SpatialGroupEnhance”,可参考【YOLO改进】主干插入注意力机制模块CBAM(基于MMYOLO)-CSDN博客

修改后的SGE.py

import torch
from torch import nn
from mmyolo.registry import MODELS

@MODELS.register_module()
class SequentialPolarizedSelfAttention(nn.Module):

    def __init__(self, in_channels=512):
        super().__init__()
        self.ch_wv=nn.Conv2d(in_channels,in_channels//2,kernel_size=(1,1))
        self.ch_wq=nn.Conv2d(in_channels,1,kernel_size=(1,1))
        self.softmax_channel=nn.Softmax(1)
        self.softmax_spatial=nn.Softmax(-1)
        self.ch_wz=nn.Conv2d(in_channels//2,in_channels,kernel_size=(1,1))
        self.ln=nn.LayerNorm(in_channels)
        self.sigmoid=nn.Sigmoid()
        self.sp_wv=nn.Conv2d(in_channels,in_channels//2,kernel_size=(1,1))
        self.sp_wq=nn.Conv2d(in_channels,in_channels//2,kernel_size=(1,1))
        self.agp=nn.AdaptiveAvgPool2d((1,1))

    def forward(self, x):
        b, c, h, w = x.size()

        #Channel-only Self-Attention
        channel_wv=self.ch_wv(x) #bs,c//2,h,w
        channel_wq=self.ch_wq(x) #bs,1,h,w
        channel_wv=channel_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        channel_wq=channel_wq.reshape(b,-1,1) #bs,h*w,1
        channel_wq=self.softmax_channel(channel_wq)
        channel_wz=torch.matmul(channel_wv,channel_wq).unsqueeze(-1) #bs,c//2,1,1
        channel_weight=self.sigmoid(self.ln(self.ch_wz(channel_wz).reshape(b,c,1).permute(0,2,1))).permute(0,2,1).reshape(b,c,1,1) #bs,c,1,1
        channel_out=channel_weight*x

        #Spatial-only Self-Attention
        spatial_wv=self.sp_wv(channel_out) #bs,c//2,h,w
        spatial_wq=self.sp_wq(channel_out) #bs,c//2,h,w
        spatial_wq=self.agp(spatial_wq) #bs,c//2,1,1
        spatial_wv=spatial_wv.reshape(b,c//2,-1) #bs,c//2,h*w
        spatial_wq=spatial_wq.permute(0,2,3,1).reshape(b,1,c//2) #bs,1,c//2
        spatial_wq=self.softmax_spatial(spatial_wq)
        spatial_wz=torch.matmul(spatial_wq,spatial_wv) #bs,1,h*w
        spatial_weight=self.sigmoid(spatial_wz.reshape(b,1,h,w)) #bs,1,h,w
        spatial_out=spatial_weight*channel_out
        return spatial_out

if __name__ == '__main__':
    input=torch.randn(1,512,7,7)
    psa = SequentialPolarizedSelfAttention(channel=
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值