ECCV 2018
1 介绍
- 在空间维度、和通道维度上进行attention操作
2 整体模型
- CBAM包含CAM(Channel Attention Module)和SAM(Spartial Attention Module)两个子模块,分别进行通道和空间上的Attention。
- 输入特征维度是C*H*W
- 通道注意力模块维度是C*1*1
- 逐元素乘到input feature中去(逐元素+广播)
- 空间注意力模块维度是1*H*W
- 逐元素乘到经过通道注意力之后的output feature中去(逐元素+广播)
2.1 Channel attention module(CAM)
- 通道注意力模块:通道维度不变,压缩空间维度。
- 该模块关注输入图片中有意义的信息(不同channel中有不同的信息)
- 将输入的feature map经过两个并行的MaxPool层和AvgPool层
- 特征图从C*H*W变为C*1*1的大小
- 然后经过Share MLP模块
- 该模块先将通道数压缩为原来的1/r(Reduction,减少率)倍,再扩张到原通道数
- 经过ReLU激活函数得到两个激活后的结果。
- 将这两个输出结果逐元素相加
- 通过一个sigmoid激活函数得到Channel Attention的输出结果(C*1*1维)
- 再将这个输出结果乘input feature【C*H*W的大小】
2.1.2 和SENet的区别
CAM与SEnet的不同之处是加了一个并行的最大池化层,提取到的高层特征更全面,更丰富。
2.2 Spatial attention module
- 空间注意力模块:空间维度不变,压缩通道维度。该模块关注的是目标的位置信息。
- 将Channel Attention的输出结果通过最大池化和平均池化得到两个1*H*W的特征图
- 然后经过Concat操作对两个特征图进行拼接
- 通过7*7卷积变为1通道的特征图(实验证明7*7效果比3*3好)
- 再经过一个sigmoid得到Spatial Attention的特征图
- 最后将输出结果乘原图变回C*H*W大小。
3 实验部分
3.1 图像分类结果
在数据集ImageNet-1K上使用ResNet网络进行对比实验
3.2 CBAM可视化
引入 CBAM 后,特征覆盖到了待识别物体的更多部位,而且最终判别物体的几率也更高,这表示注意力机制的确让网络学会了关注重点信息。
4 代码实现
#通道方向的自注意力
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
#平均池化
self.max_pool = nn.AdaptiveMaxPool2d(1)
#最大池化
self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) #kernel_size=1
self.relu1 = nn.ReLU()
self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
'''
先分别进行平均池化/最大池化
——>维度从(C,H,W)变至(C,1,1)
然后通过几层全连接,得到新的channel-wise的(C,1,1)张量
'''
#结果相加
out = avg_out + max_out
return self.sigmoid(out)
#通过sigmoid得到通道attention
#空间注意力
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
#声明卷积核为 3 或 7
padding = 3 if kernel_size == 7 else 1
#进行相应的same padding填充
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
#平均池化
max_out, _ = torch.max(x, dim=1, keepdim=True)
#最大池化
x = torch.cat([avg_out, max_out], dim=1)
#拼接操作
x = self.conv1(x)
#7x7卷积填充为3,输入通道为2,输出通道为1
return self.sigmoid(x)
#通过sigmoid得到空间attention