Global Attention Mechanism: Retain Information to Enhance Channel-Spatial Interactions(GAM)

本文介绍了一个名为GAM_Attention的PyTorch模块,它结合了通道注意力和空间注意力,用于图像处理任务。通过线性MLP和卷积操作实现特征融合,适用于ImageNet-1k实验。核心部分展示了如何构造和使用这个模块,以及其在输入数据上的操作过程。
摘要由CSDN通过智能技术生成

Codes of pytorch:

import torch.nn as nn  
import torch  


class GAM_Attention(nn.Module):  
    def __init__(self, in_channels, out_channels, rate=4):  
        super(GAM_Attention, self).__init__()  

        self.channel_attention = nn.Sequential(  
            nn.Linear(in_channels, int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Linear(int(in_channels / rate), in_channels)  ###通道注意力  MLP来实现
        )  
      
        self.spatial_attention = nn.Sequential(  
            nn.Conv2d(in_channels, int(in_channels / rate), kernel_size=7, padding=3),  
            nn.BatchNorm2d(int(in_channels / rate)),  
            nn.ReLU(inplace=True),  
            nn.Conv2d(int(in_channels / rate), out_channels, kernel_size=7, padding=3),  #空间注意力  卷积实现
            nn.BatchNorm2d(out_channels)  
        )  

    def forward(self, x):  
        b, c, h, w = x.shape  
        print("Input size:",x.shape)
        x_permute = x.permute(0, 2, 3, 1).view(b, -1, c)  #(b,c,h*w)
        print("维度转换:x_permute = x.permute(0, 2, 3, 1).view(b, -1, c):",x_permute.shape)
        x_att_permute = self.channel_attention(x_permute).view(b, h, w, c)    #(b,h,w,c)
        
        x_channel_att = x_att_permute.permute(0, 3, 1, 2)  #(b,c,h,w)
        print("送入通道注意力子模块然后恢复到原始的size以便于x计算通道注意力:x_att_permute = self.channel_attention(x_permute).view(b, h, w, c).permute(0, 3, 1, 2):",x_channel_att.shape)
      
        x = x * x_channel_att  ###计算通道注意力
        print("Get channel attention map:",x.shape)
      
        x_spatial_att = self.spatial_attention(x).sigmoid()  
        print("把通道注意力图送入空间注意力子模块:",x_spatial_att.shape)
        out = x * x_spatial_att  
        print("得到的空间通道注意力图与通道注意力图点乘得到最后的GAM注意力图:",out.shape)
      
        return out  

  

if __name__ == '__main__':  
    x = torch.randn(1, 64, 32, 48)  
    b, c, h, w = x.shape  
    net = GAM_Attention(in_channels=c, out_channels=c)  
    y = net(x)  

code results:
在这里插入图片描述

Title and authors:
在这里插入图片描述
paper address:
https://arxiv.org/pdf/2112.05561v1.pdf

Overview of GAM:
在这里插入图片描述
Channel and Spatial attention submodule
在这里插入图片描述
Experiment in ImageNet-1k
在这里插入图片描述

  • 1
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值