前言
特征融合是深度学习模型设计中提升表达能力的关键步骤,主要有三种基础方法:逐元素相加、直接乘积、通道维度拼接,从我个人的使用角度来看,Concat要更加好用。我觉得相加和乘积属于是不可逆的操作,直接融合会导致原有信息丢失,无法再分离,如果两个特征图的响应模式差异较大,这两种方式会引入噪声冗余信息。在通道维度拼接(如torch.cat(dim=1))则会保留所有原始特征,只不过Concat会增加参数量,所以我们需要对其进行降维,本篇将会讲解一写关于卷积层和线性层两种降维方式。
通道拼接与降维
当使用Concat进行通道拼接时,假设两个输入特征的通道数分别为C1和C2,则拼接后通道数变为了C1+C2,若直接输入到后续层,会导致维度不匹配。
比如下面:
在多任务学习或特征融合场景中,我们常需要将不同来源的特征进行拼接(如高低层特征融合)。但拼接后的特征通道数会成倍增加(如 128 + 128 = 256),导致后续计算量激增。
import torch
import torch.nn as nn
import torch.nn.functional as F
if __name__ == '__main__':
h_feat = torch.rand(1, 128, 512, 512)
l_feat = torch.rand(1, 128, 256, 256)
upsampled_l_feat = F.interpolate(
l_feat,
size=(512, 512),
mode='bilinear',
align_corners=False
)
merged_feat = torch.cat([h_feat, upsampled_l_feat], dim=1)
print("拼接特征:", merged_feat.shape) # torch.Size([1, 256, 512, 512])
这里对低分辨率图像进行了上采样,使其与高分辨率图像进行拼接。最后输出的通道维度变为了原来的两倍。
下面进行降维,这里提供了两种降维方法,线性降维与卷积降维,线性降维通过 permute 将通道维度移动到最后一个位置,转换为 [Batch, H, W, C],应用全连接层后,再将形状还原为 [Batch, C, H, W],适合需要全局信息融合的场景;卷积降维直接应用 1x1 卷积,保持空间分辨率(H,W 不变),仅改变通道数,适合需要保留空间信息的场景。
import torch
import torch.nn as nn
class DimReduction(nn.Module):
def __init__(self, dim, height=2):
super(DimReduction, self).__init__()
self.liner_fuse = nn.Linear(dim * height, dim)
self.conv_fuse = nn.Sequential(
nn.Conv2d(dim * height, dim, 1, stride=1),
nn.BatchNorm2d(dim),
nn.ReLU(inplace=True)
)
def forward(self, merged_feat, use_conv=True):
if use_conv:
output = self.conv_fuse(merged_feat)
else:
trans_merged_feat = merged_feat.permute(0, 2, 3, 1) # B C H W -> B H W C
output = self.liner_fuse(trans_merged_feat)
output = output.permute(0, 3, 1, 2) # B H W C -> B C H W
return output
if __name__ == '__main__':
h_feat = torch.rand(1, 128, 512, 512)
l_feat = torch.rand(1, 128, 512, 512)
merged_feat = torch.cat([h_feat, l_feat], dim=1)
print("输入特征:", merged_feat.shape) # torch.Size([1, 256, 512, 512])
dim_reduce = DimReduction(128, 2)
reduce_conv = dim_reduce(merged_feat, use_conv=True)
print("卷积层降维:", reduce_conv.shape)
reduce_liner = dim_reduce(merged_feat, use_conv=False)
print("线性层降维:", reduce_liner.shape)
降维后的特征图通道数减少,但保留了重要的特征信息,这为后续的注意力机制提供了更紧凑的特征表示,减少了计算量。
你可以在其添加通道注意力机制比如SE,适用于强调重要通道的特征,适用于降维后的特征图,因为它可以帮助模型聚焦于关键的通道信息。空间注意力比如CBAM中的空间模块,适用于突出重要空间区域的特征,适用于任何分辨率的特征图,包括降维后的结果。混合注意力比如CBAM、BAM,适合需要同时考虑通道和空间依赖关系的场景,还有自注意力比如Transformer中的注意力,适用于捕捉长距离依赖关系。
总结
关于使用哪一种能够涨点,这就是你自己的实验该做的了,没有哪一种模块是能够包能涨点的,就是很多当下的论文提出的特征融合结构,用在自己的数据集上也不一定会有效果。在缝合别人的模块的时候要从原理上,数据集上还有你自己的网络结构位置上来分析,而能否有效就要靠自己的实验是否有增长。