即插即用的涨点模块之注意力机制(SEAttention)详解及代码,可应用于检测、分割、分类等各种算法领域

目录

前言

一、SENet结构

二、SENet计算流程

三、SENet参数

四、代码讲解 


前言

Squeeze-and-Excitation Networks(SENet)

来源:CVPR2018

官方代码:GitHub - hujie-frank/SENet: Squeeze-and-Excitation Networks

        什么是通道特征?通道特征(Channel Features)是指卷积神经网络(CNN)中每个卷积核产生的输出。一个通道对应于网络中的一个卷积核,而每个通道的输出表示该卷积核在输入上的响应。通道特征捕捉了输入数据中不同方面的抽象信息。每个通道对应于某种特定的抽象特征,例如纹理、颜色、边缘等。通道特征在整个网络中负责提取和表示不同层次的信息。

        什么是通道注意力机制?通道注意力机制(Channel Attention Mechanism)是深度学习中一种用于增强通道特征捕捉能力的注意力机制。它主要应用于卷积神经网络(CNN)中,以提高模型对不同通道(channel)的特征的关注度,从而使网络更加有效地学习和利用输入数据的信息。在通道注意力机制中,通过学习每个通道的权重,模型可以在处理特定通道的特征时给予更多的注意力。这有助于网络在学习过程中更好地区分不同通道的重要性,从而提高模型对输入数据的表示能力。


一、SENet结构

SENet是一种通道注意力机制,结构如图1所示。SE注意力模块,由Squeeze操作、Excitation操作、Scale操作三部分组成。Squeeze操作:对输入的特征图进行全局平均池化,将每个通道的特征值降维为一个全局向量。这一步旨在捕捉每个通道的全局信息。Excitation操作:由两个全连接和一个ReLU激活函数、一个Softmax激活函数组成,先进行降维在升维,最后通过sigmoid函数生成权重向量,确保它们的总和为1。Scale操作:将上一步得到的通道注意力权重乘以输入的原始特征图。这一步用于调整每个通道的特征值,强调重要通道的信息,抑制不重要通道的信息。SE注意力模块与Inception、ResNet的结合,分别如图2、图3所示。

图1 SENet block

图2 原始 Inception 模块(左)和 SE-Inception 模块(右)的架构。

图3 原始 Residual 模块(左)和 SE-ResNet 模块(右)的架构

二、SENet计算流程

        如图1所示,给定一个输入X∈H'×W'×C' ,通过一个卷积变化Ftr 得到 UH×W×C 将特征U 经过Squeeze操作Fsq 在空间维度H×W 上聚合特征得到Z∈1×1×C 。接下来进行Excitation操作,得到s。其中W1 、W2 为全连接层后的权重,δ 为ReLU函数,σ 为Sigmoid函数。

s=Fex(z,W)=σ(g(z,W))=σ(W2δ(W1z))

        最后进行Scale操作,将特征U和特征s相乘得到x,通过相乘,注意力机制可以对每个通道进行更细粒度的权重调整,将更多的注意力集中在对任务更为关键的通道上。从而调整该特征值的重要性。这个过程使得网络更加关注那些在给定任务中对应通道上重要的特征。

x=Fscale (U,s)=Us

三、SENet参数

利用thop库的profile函数计算FLOPs和Param。Input:(512,7,7)。

ModuleFLOPsParam
SEAttention9113665536

四、代码讲解 

import torch
from torch import nn
from torch.nn import init


class SEAttention(nn.Module):

    def __init__(self, channel=512, reduction=16):
        super().__init__()
        # 在空间维度上,将H×W压缩为1×1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # 包含两层全连接,先降维,后升维。最后接一个sigmoid函数
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction, bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel, bias=False),
            nn.Sigmoid()
        )

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, mode='fan_out')
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                init.normal_(m.weight, std=0.001)
                if m.bias is not None:
                    init.constant_(m.bias, 0)

    def forward(self, x):
        # (B,C,H,W)
        B, C, H, W = x.size()
        # Squeeze: (B,C,H,W)-->avg_pool-->(B,C,1,1)-->view-->(B,C)
        y = self.avg_pool(x).view(B, C)
        # Excitation: (B,C)-->fc-->(B,C)-->(B, C, 1, 1)
        y = self.fc(y).view(B, C, 1, 1)
        # scale: (B,C,H,W) * (B, C, 1, 1) == (B,C,H,W)
        out = x * y
        return out




if __name__ == '__main__':
    from  torchsummary import summary
    from thop import profile
    model = SEAttention(channel=512, reduction=8)
    summary(model, (512, 7, 7), device='cpu')
    flops, params = profile(model, inputs=(torch.randn(1, 512, 7, 7),))
    print(f"FLOPs: {flops}, Params: {params}")

  • 23
    点赞
  • 29
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
注意力机制是一种常用的机制,用于在给定一组输入和一个查询时,计算输入中每个元素对于查询的重要性或相关性。对于机器学习任务,特别是自然语言处理任务,注意力机制被广泛应用于序列到序列的模型、机器翻译、文本摘要等任务中。 在以下代码中,我将为你详细解释一种常见的注意力机制:Scaled Dot-Product Attention。 ```python import torch import torch.nn as nn class ScaledDotProductAttention(nn.Module): def __init__(self): super(ScaledDotProductAttention, self).__init__() def forward(self, query, key, value): # 计算注意力得分 scores = torch.matmul(query, key.transpose(-2, -1)) scores = scores / torch.sqrt(query.size(-1)) # 使用softmax函数进行归一化 attention_weights = torch.softmax(scores, dim=-1) # 对value进行加权求和 output = torch.matmul(attention_weights, value) return output, attention_weights ``` 在这段代码中,`ScaledDotProductAttention` 类继承自 `nn.Module`,并实现了 `forward` 方法。该方法接受三个输入参数:`query`、`key` 和 `value`。这里的 `query` 表示查询向量,`key` 表示键向量,`value` 表示值向量。 在 `forward` 方法中,首先通过矩阵乘法计算注意力得分。这里使用了 `torch.matmul` 函数,将 `query` 和 `key` 进行矩阵乘法操作,得到一个注意力得分矩阵。为了缩放注意力得分,我们将其除以查询的维度的平方根。 接下来,通过 `torch.softmax` 函数对注意力得分进行归一化处理,得到注意力权重矩阵。注意力权重矩阵表示每个键向量对于查询向量的重要性或相关性。 最后,将注意力权重矩阵与值向量进行加权求和,得到最终的输出。这里使用 `torch.matmul` 函数来实现加权求和。 这就是一个简单的Scaled Dot-Product Attention注意力机制代码实现。在实际应用中,注意力机制可能会有更多的变体和扩展,以适应不同的任务和模型架构。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值