Squeeze Excitation 模块 (Pytorch)
(一)SE模块结构图
主要包括Squeeze(压缩)和Excitation(激活)两个模块
其中,squeeze_factor为第一个全连接层压缩比例,可以减少参数量。
(二)pytorch实现SE模块
class SqueezeExcitation(nn.Module):
# 参数:in_channel:输入通道数, out_channel:输出通道数, squeeze_factor=4:第一个squeeze的压缩比
def __init__(self, in_channel, out_channel, squeeze_factor=4):
super(SqueezeExcitation, self).__init__()
# 自使用平均池化,output_size指定输出的H,W的大小
self.adppool = nn.AdaptiveAvgPool2d(output_size=1)
squeeze_c = in_channel // squeeze_factor
self.fc1 = nn.Conv2d(in_channel, squeeze_c, kernel_size=1)
self.ac1 = nn.SiLU()
self.fc2 = nn.Conv2d(squeeze_c, out_channel, kernel_size=1)
self.ac2 = nn.Sigmoid()
def forward(self, x):
out = self.adppool(x)
out = self.fc1(out)
out = self.ac1(out)
out = self.fc2(out)
out = self.ac2(out)
return out*x