先附上SimAM注意力机制代码
#SimAM
class SimAM(torch.nn.Module):
def __init__(self, channels=None, out_channels=None, e_lambda=1e-4):
super(SimAM, self).__init__()
self.activaton = nn.Sigmoid()
self.e_lambda = e_lambda
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
b, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2, 3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2, 3], keepdim=True) / n + self.e_lambda)) + 0.5
return x * self.activaton(y)
这段代码定义了一个名为SimAM
的PyTorch模块,它实现了SimAM(Similarity-Aware Activation Module)的功能。SimAM是一个注意力机制模块,旨在通过计算特征图之间的相似性来增强有用的特征。
-
类定义:
class SimAM(torch.nn.Module):
:定义了一个名为SimAM
的类,它继承了PyTorch的nn.Module
基类。
-
初始化方法:
def __init__(self, channels=None, out_channels=None, e_lambda=1e-4):
:类的构造函数,其中channels
和out_channels
参数在此实现中未被使用(可能是为了与其他注意力模块保持一致的接口),而e_lambda
是一个小的常数,用于防止分母为零。self.activaton = nn.Sigmoid()
:定义了一个sigmoid激活函数self.e_lambda = e_lambda
:保存了传入的e_lambda
值。
-
字符串表示方法:
def __repr__(self):
:定义了一个方法,用于返回类的字符串表示。在这个方法中,它只返回了类名和e_lambda
的值。
-
静态方法:
@staticmethod
:这是一个装饰器,表示下面的方法是静态方法。def get_module_name():
:返回模块的名称,即"simam"。但注意,由于这是静态方法,它应该有一个self
参数或直接使用类名调用。但在此代码中,它作为实例方法也可以工作,但更合适的做法可能是作为类方法(即使用@classmethod
装饰器)。
-
前向传播方法:
def forward(self, x):
:定义了前向传播方法,它接收一个输入张量x
并返回处理后的张量。b, c, h, w = x.size()
:获取输入张量x
的维度,其中b
是批处理大小,c
是通道数,h
和w
是高和宽。n = w * h - 1
:计算除中心像素外的像素数量(但这里减1的逻辑可能取决于具体的应用场景)。x_minus_mu_square
:计算每个通道中每个像素与通道均值的差的平方。y
:计算了一个加权值,其中分母是x_minus_mu_square
的和(除以n
并加上e_lambda
以避免除以零),然后整个表达式再除以4并加0.5。这个计算可能是为了得到一个介于0和1之间的值,以便后续的sigmoid激活。return x * self.activaton(y)
:使用sigmoid激活函数对y
进行激活,并将结果与原始输入x
相乘,得到最终的输出。
其中,在前向传播中计算x_minus_mu_square时,x.mean(dim=[2, 3], keepdim=True)
是一个用于计算张量 x
在指定维度上平均值的操作。具体来说,这里是对 x
的第2个和第3个维度(索引从0开始)进行平均。
参数 dim=[2, 3]
指示了要在哪些维度上进行平均。在这个例子中,我们将对第2和第3个维度进行平均。
参数 keepdim=True
意味着输出张量的维度将与输入张量 x
相同,只不过在第2和第3个维度上大小为1。如果 keepdim=False
(默认值),则输出张量将不包含被平均掉的维度。
举个例子,假设 x
的形状是 [batch_size, channels, height, width]
(这是一个常见的四维张量形状,用于表示一批图像数据,其中每个图像有多个通道,如RGB)。那么,x.mean(dim=[2, 3], keepdim=True)
的结果将是一个形状为 [batch_size, channels, 1, 1]
的张量,其中每个 [channels, 1, 1]
的块都是对应 [channels, height, width]
块在 height
和 width
维度上的平均值。
这样,你就可以保留原始张量的批量大小和通道数,同时获得每个图像(或每个通道的图像)的平均值。
在前向传播中计算y时
-
x_minus_mu_square:
这个变量表示输入张量x
与其在通道维度(第2和第3维度)上的均值之差的平方。即,它计算了每个像素值与通道内所有像素均值的差的平方。这可以被视为一种衡量每个像素与通道内其他像素差异程度的度量。 -
x_minus_mu_square.sum(dim=[2, 3], keepdim=True):
这个操作计算了x_minus_mu_square
在第2和第3维度(即高度和宽度维度)上的和,并保持这两个维度的维度数(由于keepdim=True
)。结果是一个形状为[batch_size, channels, 1, 1]
的张量,其中每个[channels, 1, 1]
的块表示对应通道内所有像素的x_minus_mu_square
的和。 -
除以 n:
n = w * h - 1
(在代码的前面部分定义),其中w
和h
分别是输入张量x
的宽度和高度。这个除以n
的操作实际上是在计算每个通道内所有像素的x_minus_mu_square
的平均值(注意这里减去了1,可能是为了避免中心像素的权重过高,但这也取决于SimAM的具体实现和目的)。 -
加上 self.e_lambda:
self.e_lambda
是一个小的常数,用于防止分母为零或接近零的情况,从而提高数值稳定性。 -
除以 4 并加 0.5:
这两个操作是对上述计算得到的值进行缩放和平移。除以4可能是一种归一化操作,以确保后续计算的数值范围在可接受的范围内。加0.5可能是为了确保结果始终为正(因为sigmoid函数的输入通常在0到1之间)。
总之,这段代码的目的是基于输入张量 x
中每个像素与其通道内其他像素的差异程度来计算注意力权重 y
,从而实现对不同特征的选择性强调或抑制。