🔥博客主页: A_SHOWY
🎥系列专栏:力扣刷题总结录 数据结构 云计算 数字图像处理 力扣每日一题_
总体介绍
CBAM(Convolutional Block Attention Module)是一种用于增强卷积神经网络(CNN)性能的注意力机制模块。它CBAM的主要目标是通过在CNN中引入通道注意力和空间注意力来提高模型的感知能力,从而在不增加网络复杂性的情况下改善性能。
CBAM模块由两个注意力模块组成:通道注意力模块(Channel Attention Module)和空间注意力模块(Spatial Attention Module)。
通道注意力机制
整体过程是先传入这个图像特征,分别通过全局最大池化和全局平均池化得到两个特征长条,并通过两层全连接层来学习通道的权重。然后,会将处理后产生的两个结果进行相加,通过使Sigmoid函数将权重归一化到0到1之间,对每个通道进行缩放。最后,将缩放后的通道特征与原始特征相乘,以产生具有增强通道重要性的特征。
空间注意力机制
寻找通道上所有特征点的最大值和平均值(就是通道的最大值和平均值),得到两个matrix后,将两个matrix进行链接,并通过一个卷积层和Sigmoid函数来学习每个空间位置的权重。最后,将权重应用于特征图上的每个空间位置,以产生具有增强空间重要性的特征。
代码实现
import torch
from torch import nn
#通道注意力机制
class channel_attention(nn.Module):
def __init__(self,channel,ratio = 16):
super(channel_attention,self).__init__()
#最大池化贺平均池化,输出层的高贺宽都是1
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.ave_pool = nn.AdaptiveAvgPool2d(1)
#两次全连接
self.fc =nn.Sequential(
nn.Linear(channel,channel // ratio,False),
nn.ReLU(),
nn.Linear(channel // ratio, channel,False)
)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
#获得输入的x的batch,通道数,高和宽
b,c,h,w = x.size()
max_pool_out = self.max_pool(x).view([b,c])
avg_pool_out = self.ave_pool(x).view([b,c])
max_fc_out = self.fc(max_pool_out)
ave_fc_out = self.fc(avg_pool_out)
out = max_fc_out + ave_fc_out
out = self.sigmoid(out).view([b,c,1,1])
return out * x
#空间注意力机制
class spacial_attention(nn.Module):
def __init__(self,kernel = 7):
super(spacial_attention, self).__init__()
self.conv = nn.Conv2d(2,1,kernel,1,padding = 7//2,bias = False)
self.sigmoid = nn.Sigmoid()
def forward(self,x):
b, c, h, w = x.size()
#寻找通道上所有特征点的最大值和平均值
max_pool_out,_ = torch.max(x,dim = 1,keepdim = True)
mean_pool_out = torch.mean(x,dim = 1,keepdim= True)
#链接
pool_out =torch.cat([max_pool_out,mean_pool_out],dim=1)
out = self.conv(pool_out)
#相当于获得每个特征点的权值
out = self.sigmoid(out)
return out * x
class Cbam(nn.Module):
def __init__(self,channel,kernel = 7,ratio = 16):
super(Cbam,self).__init__()
self.channel_attention = channel_attention(channel,ratio)
self.spacial_attention = spacial_attention(kernel)
def forward(self,x):
x = self.channel_attention(x)
x = self.spacial_attention(x)
return x
model = Cbam(512)
print(model)
inputs = torch.ones([2,512,26,26])
outputs = model(inputs)