For details of the paper, see: https://zhuanlan.zhihu.com/p/65529934
- CBAM (Convolutional Block Attention Module)
- Channel Attention Module
import torch
import torch.nn as nn

class CAModule(nn.Module):
    '''Channel Attention Module: re-weights channels using a shared MLP over avg- and max-pooled descriptors'''
    def __init__(self, channels, reduction=16):
        super(CAModule, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # shared MLP implemented with 1x1 convolutions: channels -> channels/reduction -> channels
        self.shared_mlp = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        input = x
        avg_pool = self.avg_pool(x)
        max_pool = self.max_pool(x)
        # fuse both pooled descriptors through the shared MLP, then squash to [0, 1]
        x = self.shared_mlp(avg_pool) + self.shared_mlp(max_pool)
        x = self.sigmoid(x)
        # re-weight the input feature map channel-wise
        return input * x
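A minimal sanity check of the channel-attention module; the tensor sizes below are made up for illustration and are not from the original note:

```python
# Hypothetical usage: CAModule keeps the input shape and only re-weights channels.
x = torch.randn(2, 64, 32, 32)           # (batch, channels, height, width), made-up sizes
ca = CAModule(channels=64, reduction=16)
print(ca(x).shape)                        # torch.Size([2, 64, 32, 32])
```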
- Spatial Attention Module
class SAModule(nn.Module):
    '''Spatial Attention Module: builds a 2D attention map from channel-wise avg and max'''
    def __init__(self):
        super(SAModule, self).__init__()
        # convolution over the concatenated avg/max maps; the paper's default kernel is 7x7
        self.conv = nn.Conv2d(2, 1, kernel_size=3, padding=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        input = x
        # pool along the channel dimension -> two (N, 1, H, W) descriptors
        avg_c = torch.mean(x, dim=1, keepdim=True)
        max_c, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat((avg_c, max_c), dim=1)
        x = self.conv(x)
        x = self.sigmoid(x)
        # re-weight the input feature map position-wise
        return input * x
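A quick sketch of the intermediate shapes the spatial module works with (dummy sizes assumed for illustration):

```python
# Hypothetical shape walk-through for SAModule with a made-up input.
x = torch.randn(2, 64, 32, 32)
avg_c = torch.mean(x, dim=1, keepdim=True)        # (2, 1, 32, 32)
max_c, _ = torch.max(x, dim=1, keepdim=True)      # (2, 1, 32, 32)
print(torch.cat((avg_c, max_c), dim=1).shape)     # torch.Size([2, 2, 32, 32]) -> conv -> (2, 1, 32, 32)
print(SAModule()(x).shape)                        # torch.Size([2, 64, 32, 32]), same shape as the input
```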
- ResBlock + CBAM
def conv3x3(in_planes, out_planes, stride=1):
    '''3x3 convolution with padding'''
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        # CBAM: channel attention followed by spatial attention
        self.ca = CAModule(planes)
        self.sa = SAModule()
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        # CAModule/SAModule already multiply the attention map with their input,
        # so they are applied directly here (channel attention first, then spatial)
        out = self.ca(out)
        out = self.sa(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        # add the (possibly downsampled) shortcut before the final activation
        out += residual
        out = self.relu(out)
        return out
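A usage sketch of the CBAM-augmented block; the layer sizes and the 1x1 downsample shortcut are assumptions for illustration, not part of the original note:

```python
# Hypothetical example: a stride-2 BasicBlock with a 1x1 projection shortcut.
downsample = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=1, stride=2, bias=False),
    nn.BatchNorm2d(128))
block = BasicBlock(inplanes=64, planes=128, stride=2, downsample=downsample)
print(block(torch.randn(2, 64, 32, 32)).shape)   # torch.Size([2, 128, 16, 16])
```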
- Conclusion
- Integrating CBAM into different backbones yields consistent performance gains on both classification and detection;
- According to the comparison experiments, applying channel attention first and spatial attention second works somewhat better;
- For spatial attention, a 7x7 kernel appears to give slightly better results; a combined sketch following these two findings is given below.
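Putting the last two findings together, a small wrapper might chain channel attention first and use a 7x7 spatial kernel. The `CBAMModule` name and its configurable spatial kernel are assumptions for illustration (the `SAModule` above hard-codes a 3x3 kernel); this is a sketch, not code from the original note:

```python
# Hypothetical wrapper: channel attention first, then spatial attention with a 7x7 kernel.
class CBAMModule(nn.Module):
    def __init__(self, channels, reduction=16, spatial_kernel=7):
        super(CBAMModule, self).__init__()
        self.ca = CAModule(channels, reduction)
        # same structure as SAModule, but with a configurable (default 7x7) kernel
        self.sa_conv = nn.Conv2d(2, 1, kernel_size=spatial_kernel,
                                 padding=spatial_kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.ca(x)                                  # channel attention first
        avg_c = torch.mean(x, dim=1, keepdim=True)
        max_c, _ = torch.max(x, dim=1, keepdim=True)
        attn = self.sigmoid(self.sa_conv(torch.cat((avg_c, max_c), dim=1)))
        return x * attn                                 # spatial attention second

print(CBAMModule(channels=64)(torch.randn(2, 64, 32, 32)).shape)  # torch.Size([2, 64, 32, 32])
```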