SimAM注意力机制解析

先附上SimAM注意力机制代码

#SimAM
class SimAM(torch.nn.Module):
    def __init__(self, channels=None, out_channels=None, e_lambda=1e-4):
        super(SimAM, self).__init__()
 
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda
 
    def __repr__(self):
        s = self.__class__.__name__ + '('
        s += ('lambda=%f)' % self.e_lambda)
        return s
 
    @staticmethod
    def get_module_name():
        return "simam"
 
    def forward(self, x):
        b, c, h, w = x.size()
 
        n = w * h - 1
 
        x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
        y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
 
        return x * self.activaton(y)

        这段代码定义了一个名为SimAM的PyTorch模块,它实现了SimAM(Similarity-Aware Activation Module)的功能。SimAM是一个注意力机制模块,旨在通过计算特征图之间的相似性来增强有用的特征。
        

  1. 类定义

    • class SimAM(torch.nn.Module)::定义了一个名为SimAM的类,它继承了PyTorch的nn.Module基类。
  2. 初始化方法

    • def __init__(self, channels=None, out_channels=None, e_lambda=1e-4)::类的构造函数,其中channelsout_channels参数在此实现中未被使用(可能是为了与其他注意力模块保持一致的接口),而e_lambda是一个小的常数,用于防止分母为零。
    • self.activaton = nn.Sigmoid():定义了一个sigmoid激活函数
    • self.e_lambda = e_lambda:保存了传入的e_lambda值。
  3. 字符串表示方法

    • def __repr__(self)::定义了一个方法,用于返回类的字符串表示。在这个方法中,它只返回了类名和e_lambda的值。
  4. 静态方法

    • @staticmethod:这是一个装饰器,表示下面的方法是静态方法。
    • def get_module_name()::返回模块的名称,即"simam"。但注意,由于这是静态方法,它应该有一个self参数或直接使用类名调用。但在此代码中,它作为实例方法也可以工作,但更合适的做法可能是作为类方法(即使用@classmethod装饰器)。
  5. 前向传播方法

    • def forward(self, x)::定义了前向传播方法,它接收一个输入张量x并返回处理后的张量。
    • b, c, h, w = x.size():获取输入张量x的维度,其中b是批处理大小,c是通道数,hw是高和宽。
    • n = w * h - 1:计算除中心像素外的像素数量(但这里减1的逻辑可能取决于具体的应用场景)。
    • x_minus_mu_square:计算每个通道中每个像素与通道均值的差的平方。
    • y:计算了一个加权值,其中分母是x_minus_mu_square的和(除以n并加上e_lambda以避免除以零),然后整个表达式再除以4并加0.5。这个计算可能是为了得到一个介于0和1之间的值,以便后续的sigmoid激活。
    • return x * self.activaton(y):使用sigmoid激活函数对y进行激活,并将结果与原始输入x相乘,得到最终的输出。

        其中,在前向传播中计算x_minus_mu_square时,x.mean(dim=[2, 3], keepdim=True) 是一个用于计算张量 x 在指定维度上平均值的操作。具体来说,这里是对 x 的第2个和第3个维度(索引从0开始)进行平均。

        参数 dim=[2, 3] 指示了要在哪些维度上进行平均。在这个例子中,我们将对第2和第3个维度进行平均。

        参数 keepdim=True 意味着输出张量的维度将与输入张量 x 相同,只不过在第2和第3个维度上大小为1。如果 keepdim=False(默认值),则输出张量将不包含被平均掉的维度。

        举个例子,假设 x 的形状是 [batch_size, channels, height, width](这是一个常见的四维张量形状,用于表示一批图像数据,其中每个图像有多个通道,如RGB)。那么,x.mean(dim=[2, 3], keepdim=True) 的结果将是一个形状为 [batch_size, channels, 1, 1] 的张量,其中每个 [channels, 1, 1] 的块都是对应 [channels, height, width] 块在 height 和 width 维度上的平均值。

        这样,你就可以保留原始张量的批量大小和通道数,同时获得每个图像(或每个通道的图像)的平均值。

        在前向传播中计算y时

  1. x_minus_mu_square
    这个变量表示输入张量 x 与其在通道维度(第2和第3维度)上的均值之差的平方。即,它计算了每个像素值与通道内所有像素均值的差的平方。这可以被视为一种衡量每个像素与通道内其他像素差异程度的度量。

  2. x_minus_mu_square.sum(dim=[2, 3], keepdim=True)
    这个操作计算了 x_minus_mu_square 在第2和第3维度(即高度和宽度维度)上的和,并保持这两个维度的维度数(由于 keepdim=True)。结果是一个形状为 [batch_size, channels, 1, 1] 的张量,其中每个 [channels, 1, 1] 的块表示对应通道内所有像素的 x_minus_mu_square 的和。

  3. 除以 n
    n = w * h - 1(在代码的前面部分定义),其中 w 和 h 分别是输入张量 x 的宽度和高度。这个除以 n 的操作实际上是在计算每个通道内所有像素的 x_minus_mu_square 的平均值(注意这里减去了1,可能是为了避免中心像素的权重过高,但这也取决于SimAM的具体实现和目的)。

  4. 加上 self.e_lambda
    self.e_lambda 是一个小的常数,用于防止分母为零或接近零的情况,从而提高数值稳定性。

  5. 除以 4 并加 0.5
    这两个操作是对上述计算得到的值进行缩放和平移。除以4可能是一种归一化操作,以确保后续计算的数值范围在可接受的范围内。加0.5可能是为了确保结果始终为正(因为sigmoid函数的输入通常在0到1之间)。

        总之,这段代码的目的是基于输入张量 x 中每个像素与其通道内其他像素的差异程度来计算注意力权重 y,从而实现对不同特征的选择性强调或抑制。

  • 18
    点赞
  • 17
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

钧尘

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值