(即插即用模块-Attention部分) 七、(WACV 2021) Triplet Attention 三重注意力

在这里插入图片描述

paper:Rotate to Attend: Convolutional Triplet Attention Module

Code:https://github.com/LandskapeAI/triplet-attention


1、Triplet Attention

论文首先分析了 CBAM 和 SENet 存在的一些问题,指出这两个注意力都需要一定数量的可学习参数来建立通道间的相互依赖关系。该过程的缺点在于通道注意力和空间注意力是彼此独立地分离和计算的。

基于此,这篇论文提出一种三重注意力(Triplet Attention),其通过捕获交叉维度的相互作用来计算注意力的权重。对于一个输入张量,三元组注意力通过旋转操作建立维度间依赖关系,然后进行残差变换,并以可忽略的计算开销编码通道间和空间信息。与之前的注意力工作不同的是,Triplet Attention引入了跨维相互作用,强调了跨维度交互的重要性。没有降维的过程,从而消除了通道和权重之间的间接对应关系。

三重注意力,顾名思义,由三个平行分支组成,其中两个负责捕获通道维度C与空间维度H或W之间的跨维度相互作用。剩下的最后一个分支类似于CBAM,用于建立空间注意力。所有三个分支的输出都使用简单平均值进行聚合。


Triplet Attention 与SENet、CBAM、GC的比较

在这里插入图片描述


Triplet Attention 的具体操作:对于给定的一个输入张量 X,首先将其传递到所提出的三重注意力模块中的三个分支中的每一个。

  1. 第一个分支:通道注意力计算分支,输入特征 X 经过Z-Pool,再接着7 x 7卷积,最后Sigmoid激活函数生成通道注意力权重。
  2. 第二个分支:通道C和空间W维度交互捕获分支,输入特征先经过permute变换维度特征,接着在H维度上进行Z-Pool,后面操作类似。最后需要经过permuter变为原维度特征,方便进行element-wise相加。
  3. 第三个分支:通道C和空间H维度交互捕获分支,输入特征先经过permute,变换维度特征,接着在W维度上进行Z-Pool,后面操作类似。最后需要经过permuter变为原维度特征,方便进行element-wise相加。
  4. 最后对3个分支输出特征进行相加 获得 最终输出 注意力图。

Triplet Attention 结构图:

在这里插入图片描述


2、代码实现

import torch
import torch.nn as nn


class BasicConv(nn.Module):
    def __init__(self,in_planes,out_planes,kernel_size,stride=1,padding=0,dilation=1,
                 groups=1,bias=False,):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes,out_planes,kernel_size=kernel_size,stride=stride,
                              padding=padding,dilation=dilation,groups=groups,bias=bias,)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid_(x_out)
        return x * scale


class TripletAttention(nn.Module):
    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=["avg", "max"],
        no_spatial=False,
    ):
        super(TripletAttention, self).__init__()
        self.ChannelGateH = SpatialGate()
        self.ChannelGateW = SpatialGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_perm1 = x.permute(0, 2, 1, 3).contiguous()
        x_out1 = self.ChannelGateH(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3).contiguous()
        x_perm2 = x.permute(0, 3, 2, 1).contiguous()
        x_out2 = self.ChannelGateW(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1).contiguous()
        if not self.no_spatial:
            x_out = self.SpatialGate(x)
            x_out = (1 / 3) * (x_out + x_out11 + x_out21)
        else:
            x_out = (1 / 2) * (x_out11 + x_out21)
        return x_out


if __name__ == '__main__':
    x = torch.randn(4, 512, 7, 7).cuda()
    model = TripletAttention(512).cuda()
    out = model(x)
    print(out.shape)

本文只是对论文中的即插即用模块做了整合,对论文中的一些地方难免有遗漏之处,如果想对这些模块有更详细的了解,还是要去读一下原论文,肯定会有更多收获。

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

御宇w

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值