一、CBAM
-
CBAM由通道注意力机制与空间注意力机制两部分组成
-
通道注意力机制模块如下:
-
首先特征层 Feature通过全局平均池化和全局最大池化,处理后的特征层通过共享全连接层分别生成Feature1和Feature2,最后对Feature1和Feature2进行相加和Sigmoid操作生成对应Feature每个特征层的(0—1)之间的权值,将channel attention 与Feature 进行相乘后完成CBAM通道注意力机制部分,得到新的特征Feature’
-
空间注意力机制模块如下:
-
首先将从通道注意力机制模块得到的进行Feature’ 进行 全局平均池化和全局最大池化并在channel维度拼接,将拼接后的特征层输入卷积核后再经过sigmoid操作,获得对应特征层每一个点(0—1)之间的权值,最后将Spatial attention 与原Feature’进行相乘得到最后的输出
-
N = (W - F + 2P) / S + 1 W:输入 F:卷积核 P:padding S:步距
-
代码中padding=kernel_size//2 是因为在S=1的情况下 这样可以设置可以保证输入与输出的Feature形状不变
二、CBAM的pytorch
2.1 通道注意力机制实现
class channel_attention(nn.Module):
def __init__(self, channel, ratio=16):
super(channel_attention, self).__init__()
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // ratio, False),
nn.ReLU(),
nn.Linear(channel // ratio, channel, False)
)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, h, w = x.size()
max_pool__out = self.max_pool(x).view([b, c])
avg_pool__out = self.avg_pool(x).view([b, c])
max_fc_out = self.fc(max_pool__out)
avg_fc_out = self.fc(avg_pool__out)
out = avg_fc_out + max_fc_out
out = self.sigmoid(out).view([b, c, 1, 1])
return x * out
2.1 空间注意力机制实现
class spacial_attention(nn.Module):
def __init__(self, kernel_size=7):
super(spacial_attention, self).__init__()
# N = (W - F + 2P) / S + 1
padding=kernel_size//2
self.conv = nn.Conv2d(2,1,kernel_size,stride=1,padding=padding,bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
b, c, h, w = x.size()
max_pool__out = torch.max(x,dim=1)
mean_pool__out =torch.mean(x,dim=1)
pool_out=torch.cat([max_pool__out,mean_pool__out],dim=1)
out= self.conv(pool_out)
out=self.sigmoid(out)
return out
2.2 CBAM实现
class Cbam(nn.Module):
def __index__(self, channel,ratio=16,kernel_size=7):
super(Cbam,self).__init__()
self.channel_attention=channel_attention(channel_attention,channel,ratio)
self.spacial_attention=spacial_attention(channel_attention,kernel_size)
def forward(self,x):
x=self.channel_attention(x)
x=self.spacial_attention(x)
return x