-
SA(Spatial Attention)模块:
-
Spatial Attention (SA) 模块是一种用于增强卷积神经网络(CNN)中特征图空间信息的技术。它通过学习输入特征图的空间注意力图来强调重要的空间区域,从而帮助模型更专注于图像中的关键部分。以下是 SA 模块的工作原理及代码实现的详细解释:
### 工作原理
1. **全局最大池化**:对每个通道进行全局最大池化操作,得到一个代表每个位置上所有通道的最大值的特征图。这一步骤可以帮助捕捉到最显著的特征。
2. **全局平均池化**:对每个通道进行全局平均池化操作,得到一个代表每个位置上所有通道的平均值的特征图。这一步骤有助于捕捉到更为平滑的特征分布。
3. **特征融合**:将上述两个特征图在通道维度上拼接起来,形成一个新的多通道特征图。这个新的特征图包含了关于原始特征图的重要性和分布的信息。
4. **卷积操作**:使用一个 \(1 \times 1\) 卷积层(或本例中的 \(7 \times 7\) 卷积层,取决于 `kernel_size` 参数设置)来减少通道数至 1,同时学习空间上的注意力权重。这里不使用偏置项 (`bias=False`) 是为了简化模型。
5. **Sigmoid 激活函数**:最后,使用 Sigmoid 函数将卷积输出转换为范围在 [0, 1] 之间的注意力权重图。这些权重可以被用作后续层的加权因子,以增强或抑制某些空间区域的重要性。
### 代码实现
- **初始化**:`SpatialAttention` 类继承自 `nn.Module`,在构造函数中定义了卷积层和 Sigmoid 激活函数。卷积层的输入通道数为 2(因为我们将最大池化和平均池化的结果拼接在一起),输出通道数为 1(因为我们希望最终得到一个单通道的注意力图)。根据选择的 `kernel_size`,设置了相应的填充大小以保持输出尺寸不变。
- **前向传播**:在 `forward` 方法中,首先对输入张量 `x` 进行维度变换,以便进行全局池化操作。然后分别计算最大池化和平均池化的结果,并在通道维度上拼接这两个结果。接下来,通过卷积层处理拼接后的特征图,最后应用 Sigmoid 激活函数生成最终的注意力图。
- **测试**:在 `__main__` 块中,创建了一个随机输入张量 `data`,并通过实例化 `SpatialAttention` 类来测试模块的功能。输出张量的形状为 `[1, 1, 256, 256]`,表示经过空间注意力机制后,我们得到了一个与输入相同空间尺寸的单通道注意力图。
通过这种方式,SA 模块能够有效地提升 CNN 对于图像中重要区域的关注度,进而提高模型的性能。
-
-
SE(Squeeze-and-Excitation)模块:
-
工作原理
-
Squeeze(挤压):通过对输入特征图进行全局平均池化操作,将每个通道的空间信息压缩成一个单一的数值。这样,每个通道就只剩下一个标量值,表示该通道在整个特征图上的平均激活值。这一步骤的结果是一个形状为
[B, C, 1, 1]
的张量,其中B
是批次大小,C
是通道数。 -
Excitation(激励):使用两个全连接层(或等效的线性层)来建模通道间的依赖关系。首先,将挤压后的特征图通过一个降维的全连接层,将通道数减少到原来的
1/reduction_ratio
。接着,通过 ReLU 激活函数引入非线性。然后,再通过一个升维的全连接层恢复到原来的通道数。最后,使用 Sigmoid 激活函数将输出转换为范围在 [0, 1] 之间的注意力权重。 -
Scale(缩放):将得到的注意力权重与原始输入特征图逐元素相乘,实现对每个通道的特征响应的重新加权。这样,重要的通道会被放大,而不太重要的通道会被抑制。
-
代码实现
-
初始化:
input_channels
:输入特征图的通道数。reduction_ratio
:降维的比例,默认为 16。global_avg_pool
:全局平均池化层,用于将特征图尺寸压缩为[B, C, 1, 1]
。fc1
和fc2
:两个全连接层,分别用于降维和升维。relu
和sigmoid
:ReLU 激活函数和 Sigmoid 激活函数。
-
前向传播:
- 首先获取输入特征图的形状
[B, C, H, W]
。 - 使用全局平均池化层将特征图压缩为
[B, C, 1, 1]
。 - 将压缩后的特征图展平为
[B, 1, 1, C]
,以便通过全连接层。 - 通过第一个全连接层
fc1
进行降维,然后通过 ReLU 激活函数。 - 通过第二个全连接层
fc2
进行升维,然后通过 Sigmoid 激活函数,得到形状为[B, 1, 1, C]
的注意力权重。 - 将注意力权重展平为
[B, 1, H, W]
,并与原始输入特征图逐元素相乘,得到最终的输出特征图。
- 首先获取输入特征图的形状
-
这些模块主要用于增强网络对重要信息的关注度,从而提高模型性能。