(即插即用模块-特征处理部分) 十三、(TGRS 2023) CAFM 交叉注意力融合模块

在这里插入图片描述

paper:Attention Multihop Graph and Multiscale Convolutional Fusion Network for Hyperspectral Image Classification

Code:https://github.com/EdwardHaoz/IEEE_TGRS_AMGCFN


1、Cross-Attention Fusion Module

在现有的CNN 和 GCN中,存在一些局限性: 即虽然CNN和GCN都能有效提取特征,但它们分别侧重于像素级和超像素级信息。直接融合两者得到的特征往往不够充分,难以有效提升分类性能。此外,现有融合方法也存在着一些不足: 现有的融合方法大多采用简单的加权融合,缺乏对特征重要性的考虑,无法有效地突出重要特征,导致融合效果不佳。所以这篇论文提出一种 交叉注意力融合模块(Cross-Attention Fusion Module)

CAFM 的基本原理是通过交叉注意力机制,将 PMCsN 和 MGCsN 提取的特征进行交互和融合,以获得更具判别力的特征。

CAFM 包含两个部分:通道注意力交叉模块和空间注意力融合模块。其具体实现过程如下:

  1. 通道注意力交叉模块:首先对两个子网络的特征分别进行全局最大池化和平均池化,得到两个通道描述。其中,全局最大池化操作会提取每个通道的最大值,而全局平均池化操作会提取每个通道的平均值,从而分别得到两个不同的通道描述。

    然后将两个通道描述输入到一个共享的两层神经网络,该神经网络包含一个 ReLU 激活函数。通过两层神经网络,得到两个通道权重系数。再将两个通道权重系数相乘,得到一个交叉矩阵。最后将交叉矩阵分别与两个子网络的特征相乘,得到融合后的通道特征。

  2. 空间注意力融合模块:在空间层面,首先对两个子网络的特征分别进行最大池化和平均池化,得到两个空间描述。最大池化操作会提取每个像素的最大值,而平均池化操作会提取每个像素的平均值,从而分别得到两个不同的空间描述。

    然后将两个空间描述在通道维度进行拼接,得到一个新的特征图。再将拼接后的特征图输入到一个共享的卷积层。再用一个卷积层学习空间特征,并得到空间权重系数。最后将空间权重系数分别与两个子网络的特征相乘,得到融合后的空间特征。

  3. 残差连接:最后将融合后的特征与输入特征进行残差连接,得到最终的融合特征。残差连接可以增强网络的鲁棒性,并有助于网络学习更深层次的特征。


Cross-Attention Fusion Module 结构图:
在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops.einops import rearrange


class CAFM(nn.Module):  # Cross Attention Fusion Module
    def __init__(self, channels):
        super(CAFM, self).__init__()

        self.conv1_spatial = nn.Conv2d(2, 1, 3, stride=1, padding=1, groups=1)
        self.conv2_spatial = nn.Conv2d(1, 1, 3, stride=1, padding=1, groups=1)

        self.avg1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
        self.avg2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
        self.max1 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)
        self.max2 = nn.Conv2d(channels, 64, 1, stride=1, padding=0)

        self.avg11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
        self.avg22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
        self.max11 = nn.Conv2d(64, channels, 1, stride=1, padding=0)
        self.max22 = nn.Conv2d(64, channels, 1, stride=1, padding=0)

    def forward(self, f1, f2):
        b, c, h, w = f1.size()

        f1 = f1.reshape([b, c, -1])
        f2 = f2.reshape([b, c, -1])

        avg_1 = torch.mean(f1, dim=-1, keepdim=True).unsqueeze(-1)
        max_1, _ = torch.max(f1, dim=-1, keepdim=True)
        max_1 = max_1.unsqueeze(-1)

        avg_1 = F.relu(self.avg1(avg_1))
        max_1 = F.relu(self.max1(max_1))
        avg_1 = self.avg11(avg_1).squeeze(-1)
        max_1 = self.max11(max_1).squeeze(-1)
        a1 = avg_1 + max_1

        avg_2 = torch.mean(f2, dim=-1, keepdim=True).unsqueeze(-1)
        max_2, _ = torch.max(f2, dim=-1, keepdim=True)
        max_2 = max_2.unsqueeze(-1)

        avg_2 = F.relu(self.avg2(avg_2))
        max_2 = F.relu(self.max2(max_2))
        avg_2 = self.avg22(avg_2).squeeze(-1)
        max_2 = self.max22(max_2).squeeze(-1)
        a2 = avg_2 + max_2

        cross = torch.matmul(a1, a2.transpose(1, 2))

        a1 = torch.matmul(F.softmax(cross, dim=-1), f1)
        a2 = torch.matmul(F.softmax(cross.transpose(1, 2), dim=-1), f2)

        a1 = a1.reshape([b, c, h, w])
        avg_out = torch.mean(a1, dim=1, keepdim=True)
        max_out, _ = torch.max(a1, dim=1, keepdim=True)
        a1 = torch.cat([avg_out, max_out], dim=1)
        a1 = F.relu(self.conv1_spatial(a1))
        a1 = self.conv2_spatial(a1)
        a1 = a1.reshape([b, 1, -1])
        a1 = F.softmax(a1, dim=-1)

        a2 = a2.reshape([b, c, h, w])
        avg_out = torch.mean(a2, dim=1, keepdim=True)
        max_out, _ = torch.max(a2, dim=1, keepdim=True)
        a2 = torch.cat([avg_out, max_out], dim=1)
        a2 = F.relu(self.conv1_spatial(a2))
        a2 = self.conv2_spatial(a2)
        a2 = a2.reshape([b, 1, -1])
        a2 = F.softmax(a2, dim=-1)

        f1 = f1 * a1 + f1
        f2 = f2 * a2 + f2

        f1 = f1.squeeze(0)
        f2 = f2.squeeze(0)

        return f1.transpose(0, 1), f2.transpose(0, 1)


if __name__ == '__main__':
    """
    本来CAFM的输入通道是固定的128,我在这里加了个参数
    CAFM 的结果有两个,并且维度顺序是乱的,可以先相加,再调维度顺序
    """
    H, W = 7, 7
    x = torch.randn(4, 512, 7, 7).cuda()
    y = torch.randn(4, 512, 7, 7).cuda()
    model = CAFM(512).cuda()
    out_1,out_2 = model(x,y)

    out = out_1 + out_2
    out = out.permute(1, 2, 0)
    out = rearrange(out, 'b (h w) c -> b c h w', h=H, w=W)

    print(out.shape)

评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值