https://github.com/Sunshine066/my-face-althorithm
The complete code is on GitHub, including MobileFaceNet variants with each of the existing attention mechanisms added; feel free to review it and point out problems. Questions and discussion are welcome.
Only part of the attention-mechanism code is shown below.
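The snippets below assume the usual PyTorch imports plus the _make_divisible channel-rounding helper that the SE block calls but the excerpt does not define. A standard version (matching the one used in torchvision's MobileNet implementations; the repo's own copy may differ) is sketched here so the code runs stand-alone:

import torch
import torch.nn as nn
import torch.nn.functional as F

def _make_divisible(ch, divisor=8, min_ch=None):
    # Round the channel count to the nearest multiple of `divisor`,
    # never dropping more than 10% below the original value.
    if min_ch is None:
        min_ch = divisor
    new_ch = max(min_ch, int(ch + divisor / 2) // divisor * divisor)
    if new_ch < 0.9 * ch:
        new_ch += divisor
    return new_ch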
class SqueezeExcitation(nn.Module):
    # squeeze_factor=4: the first FC layer has 1/4 as many nodes as the input has channels
    def __init__(self, input_c: int, squeeze_factor: int = 4):
        super(SqueezeExcitation, self).__init__()
        # number of nodes in the first FC layer, rounded to a multiple of 8
        squeeze_c = _make_divisible(input_c // squeeze_factor, 8)
        # 1x1 convolutions replace the FC layers; the effect is identical
        self.fc1 = nn.Conv2d(input_c, squeeze_c, 1)
        self.fc2 = nn.Conv2d(squeeze_c, input_c, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # global average pooling: output_size=(1, 1) reduces each channel to a single number
        scale = F.adaptive_avg_pool2d(x, output_size=(1, 1))
        scale = self.fc1(scale)
        scale = F.relu(scale, inplace=True)
        scale = self.fc2(scale)
        # `scale` now holds the output of the second FC layer
        scale = F.hardsigmoid(scale, inplace=True)
        return scale * x  # multiply with the original input to get the SE output
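A quick smoke test (hypothetical shapes, just to illustrate the channel-attention behaviour): the module preserves the input shape and only rescales each channel by a learned weight in [0, 1].

se = SqueezeExcitation(input_c=64)
feat = torch.randn(2, 64, 7, 7)   # (N, C, H, W)
out = se(feat)
print(out.shape)                  # torch.Size([2, 64, 7, 7]) -- same shape, channels rescaled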
Next is the author's addition (marked "yt add" in the repo): a 3D SE module, i.e. a spatial attention block.
class conv_block(nn.Module):
    # Two stacked 3x3 Conv-BN-ReLU layers; spatial size is preserved (stride=1, padding=1).
    def __init__(self, ch_in, ch_out):
        super(conv_block, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True),
            nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1, bias=True),
            nn.BatchNorm2d(ch_out),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        x = self.conv(x)
        return x
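For reference (hypothetical channel counts), conv_block only changes the channel dimension:

block = conv_block(ch_in=32, ch_out=64)
print(block(torch.randn(1, 32, 14, 14)).shape)   # torch.Size([1, 64, 14, 14])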
class SqueezeAttentionBlock(nn.Module):
    def __init__(self, ch_in, ch_out):
        super(SqueezeAttentionBlock, self).__init__()
        self.avg_pool = nn.AvgPool2d(kernel_size=2, stride=2)
        self.conv = conv_block(ch_in, ch_out)
        self.conv_atten = conv_block(ch_in, ch_out)
        self.upsample = nn.Upsample(scale_factor=2)

    def forward(self, x):
        x_res = self.conv(x)       # trunk branch at full resolution
        y = self.avg_pool(x)       # attention branch: downsample ...
        y = self.conv_atten(y)     # ... attend at half resolution ...
        y = self.upsample(y)       # ... and upsample back
        # yt add: an odd-sized feature map shrinks to floor(H/2) when pooled, so
        # upsampling by 2 yields H-1; resize to the trunk's spatial size to match.
        if y.shape != x_res.shape:
            y = F.interpolate(y, size=x_res.shape[2:], mode='nearest')
        return (y * x_res) + y
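A quick check with an odd spatial size (hypothetical dimensions) exercises the resize branch: pooling takes 7 to 3, upsampling takes 3 to 6, and the interpolation restores 7x7 to match the trunk output.

sab = SqueezeAttentionBlock(ch_in=32, ch_out=64)
out = sab(torch.randn(1, 32, 7, 7))   # odd H and W
print(out.shape)                      # torch.Size([1, 64, 7, 7])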