1、Cross-Attention Fusion Module
在现有的CNN 和 GCN中,存在一些局限性: 即虽然CNN和GCN都能有效提取特征,但它们分别侧重于像素级和超像素级信息。直接融合两者得到的特征往往不够充分,难以有效提升分类性能。此外,现有融合方法也存在着一些不足: 现有的融合方法大多采用简单的加权融合,缺乏对特征重要性的考虑,无法有效地突出重要特征,导致融合效果不佳。所以这篇论文提出一种 交叉注意力融合模块(Cross-Attention Fusion Module)。
CAFM 的基本原理是通过交叉注意力机制,将 PMCsN 和 MGCsN 提取的特征进行交互和融合,以获得更具判别力的特征。
CAFM 包含两个部分:通道注意力交叉模块和空间注意力融合模块。其具体实现过程如下:
-
通道注意力交叉模块:首先对两个子网络的特征分别进行全局最大池化和平均池化,得到两个通道描述。其中,全局最大池化操作会提取每个通道的最大值,而全局平均池化操作会提取每个通道的平均值,从而分别得到两个不同的通道描述。
然后将两个通道描述输入到一个共享的两层神经网络,该神经网络包含一个 ReLU 激活函数。通过两层神经网络,得到两个通道权重系数。再将两个通道权重系数相乘,得到一个交叉矩阵。最后将交叉矩阵分别与两个子网络的特征相乘,得到融合后的通道特征。
-
空间注意力融合模块:在空间层面,首先对两个子网络的特征分别进行最大池化和平均池化,得到两个空间描述。最大池化操作会提取每个像素的最大值,而平均池化操作会提取每个像素的平均值,从而分别得到两个不同的空间描述。
然后将两个空间描述在通道维度进行拼接,得到一个新的特征图。再将拼接后的特征图输入到一个共享的卷积层。再用一个卷积层学习空间特征,并得到空间权重系数。最后将空间权重系数分别与两个子网络的特征相乘,得到融合后的空间特征。
-
残差连接:最后将融合后的特征与输入特征进行残差连接,得到最终的融合特征。残差连接可以增强网络的鲁棒性,并有助于网络学习更深层次的特征。
Cross-Attention Fusion Module 结构图:
2、代码实现
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange
class CAFM(nn.Module): # Cross Attention Fusion Module
def __init__(self, channels):
super(CAFM, self).__init__()
self.conv1_spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, groups=1)
self.conv2_spatial = nn.Conv2d(1, 1, 3, stride=1, padding=1, groups=1)
self.avg1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
self.avg2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
self.max1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
self.max2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
self.avg11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
self.avg22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
self.max11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
self.max22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
def forward(self, f1, f2):
b, c, h, w = f1.size()
f1 = f1.reshape([b, c, -1])
f2 = f2.reshape([b, c, -1])
avg_1 = torch.mean(f1, dim=-1, keepdim=True).unsqueeze(-1)
max_1, _ = torch.max(f1, dim=-1, keepdim=True)
max_1 = max_1.unsqueeze(-1)
avg_1 = F.relu(self.avg1(avg_1))
max_1 = F.relu(self.max1(max_1))
avg_1 = self.avg11(avg_1).squeeze(-1)
max_1 = self.max11(max_1).squeeze(-1)
a1 = avg_1 + max_1
avg_2 = torch.mean(f2, dim=-1, keepdim=True).unsqueeze(-1)
max_2, _ = torch.max(f2, dim=-1, keepdim=True)
max_2 = max_2.unsqueeze(-1)
avg_2 = F.relu(self.avg2(avg_2))
max_2 = F.relu(self.max2(max_2))
avg_2 = self.avg22(avg_2).squeeze(-1)
max_2 = self.max22(max_2).squeeze(-1)
a2 = avg_2 + max_2
cross = torch.matmul(a1, a2.transpose(1, 2))
a1 = torch.matmul(F.softmax(cross, dim=-1), f1)
a2 = torch.matmul(F.softmax(cross.transpose(1, 2), dim=-1), f2)
a1 = a1.reshape([b, c, h, w])
avg_out = torch.mean(a1, dim=1, keepdim=True)
max_out, _ = torch.max(a1, dim=1, keepdim=True)
a1 = torch.cat([avg_out, max_out], dim=1)
a1 = F.relu(self.conv1_spatial(a1))
a1 = self.conv2_spatial(a1)
a1 = a1.reshape([b, 1, -1])
a1 = F.softmax(a1, dim=-1)
a2 = a2.reshape([b, c, h, w])
avg_out = torch.mean(a2, dim=1, keepdim=True)
max_out, _ = torch.max(a2, dim=1, keepdim=True)
a2 = torch.cat([avg_out, max_out], dim=1)
a2 = F.relu(self.conv1_spatial(a2))
a2 = self.conv2_spatial(a2)
a2 = a2.reshape([b, 1, -1])
a2 = F.softmax(a2, dim=-1)
f1 = f1 * a1 + f1
f2 = f2 * a2 + f2
f1 = f1.squeeze(0)
f2 = f2.squeeze(0)
return f1.transpose(0, 1), f2.transpose(0, 1)
if __name__ == '__main__':
"""
本来CAFM的输入通道是固定的128,我在这里加了个参数
CAFM 的结果有两个,并且维度顺序是乱的,可以先相加,再调维度顺序
"""
H, W = 7, 7
x = torch.randn(4, 512, 7, 7).cuda()
y = torch.randn(4, 512, 7, 7).cuda()
model = CAFM(512).cuda()
out_1,out_2 = model(x,y)
out = out_1 + out_2
out = out.permute(1, 2, 0)
out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)
print(out.shape)