pytorch代码实现注意力机制之MLCA

MLCA注意力机制

简要:注意力机制是计算机视觉中使用最广泛的组件之一,可以帮助神经网络强调重要元素并抑制不相关的元素。然而,绝大多数信道注意力机制仅包含信道特征信息而忽略了空间特征信息,导致模型表示效果或目标检测性能较差,空间注意力模块往往复杂且成本高昂。为了在性能和复杂度之间取得平衡,该文提出一种轻量级的混合本地信道注意(MLCA)模块来提升目标检测网络的性能,该模块可以同时结合信道信息和空间信息,以及局部信息和全局信息来提高网络的表达效果。在此基础上,提出了用于比较各种注意力模块性能的MobileNet-Attention-YOLO(MAY)算法。在 Pascal VOC 和 SMID 数据集上,MLCA 在模型表示的功效、性能和复杂性之间实现了比其他注意力技术更好的平衡。与PASCAL VOC数据集上的Squeeze-and-Excitation(SE)注意力机制和SIMD数据集上的Coordinate Attention(CA)方法相比,mAP分别提高了1.0%和1.5%。

原文地址:Mixed local channel attention for object detection

结构图

pytorch代码实现MLCA

import math, torch
from torch import nn
import torch.nn.functional as F

class MLCA(nn.Module):
    def __init__(self, in_size, local_size=5, gamma = 2, b = 1,local_weight=0.5):
        super(MLCA, self).__init__()

        # ECA 计算方法
        self.local_size=local_size
        self.gamma = gamma
        self.b = b
        t = int(abs(math.log(in_size, 2) + self.b) / self.gamma)   # eca  gamma=2
        k = t if t % 2 else t + 1

        self.conv = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)
        self.conv_local = nn.Conv1d(1, 1, kernel_size=k, padding=(k - 1) // 2, bias=False)

        self.local_weight=local_weight

        self.local_arv_pool = nn.AdaptiveAvgPool2d(local_size)
        self.global_arv_pool=nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        local_arv=self.local_arv_pool(x)
        global_arv=self.global_arv_pool(local_arv)

        b,c,m,n = x.shape
        b_local, c_local, m_local, n_local = local_arv.shape

        # (b,c,local_size,local_size) -> (b,c,local_size*local_size) -> (b,local_size*local_size,c) -> (b,1,local_size*local_size*c)
        temp_local= local_arv.view(b, c_local, -1).transpose(-1, -2).reshape(b, 1, -1)
        # (b,c,1,1) -> (b,c,1) -> (b,1,c)
        temp_global = global_arv.view(b, c, -1).transpose(-1, -2)

        y_local = self.conv_local(temp_local)
        y_global = self.conv(temp_global)

        # (b,c,local_size,local_size) <- (b,c,local_size*local_size)<-(b,local_size*local_size,c) <- (b,1,local_size*local_size*c)
        y_local_transpose=y_local.reshape(b, self.local_size * self.local_size,c).transpose(-1,-2).view(b, c, self.local_size , self.local_size)
        # (b,1,c) -> (b,c,1) -> (b,c,1,1)
        y_global_transpose = y_global.transpose(-1,-2).unsqueeze(-1)

        # 反池化
        att_local = y_local_transpose.sigmoid()
        att_global = F.adaptive_avg_pool2d(y_global_transpose.sigmoid(),[self.local_size, self.local_size])
        att_all = F.adaptive_avg_pool2d(att_global*(1-self.local_weight)+(att_local*self.local_weight), [m, n])

        x = x * att_all
        return x

if __name__ == '__main__':
    attention = MLCA(in_size=256)
    inputs = torch.randn((2, 256, 16, 16))
    result = attention(inputs)
    print(result.size())
  • 13
    点赞
  • 14
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

我悟了-

你的激励是我肝下去的动力~

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值