While adding CBAM by following this blog post (https://blog.csdn.net/qq_38410428/article/details/103694759), I hit a dimension-mismatch error:
RuntimeError: Given groups=1, weight of size [64, 64, 1, 1], expected input[32, 1, 1, 1] to have 64 channels, but got 1 channels instead
Debugging showed that the problem arises before the Bottleneck blocks: after passing through the CA (channel attention) and SA (spatial attention) modules, the feature map collapses to (B, 1, 1, 1). After modifying and testing the code, the problem was solved.

The fix is as follows. The CA and SA modules produce outputs of shape (B, C, 1, 1) and (B, 1, H, W) respectively. The original code multiplied these attention maps with the feature map (B, C, H, W) inside the ResNet class's forward; I changed it so that the multiplication happens directly inside the CA/SA modules, and adjusted ResNet's forward accordingly.

My guess is that this problem is caused by a difference in PyTorch versions.

You should adapt your own code in the same way. Everyone is welcome to discuss.
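To see why the shapes in the error message arise, here is a minimal sketch (tensor sizes are made up to match the error, B=32 and C=64) of what happens when the raw attention maps, rather than the reweighted feature maps, flow down the network:

```python
import torch
import torch.nn as nn

# If channel attention returns only its weights, the "feature map" flowing
# onward is (B, C, 1, 1) instead of (B, C, H, W).
ca_out = torch.rand(32, 64, 1, 1)

# Spatial attention then pools over channels and applies a 7x7 conv,
# collapsing the tensor to (B, 1, 1, 1).
avg_out = torch.mean(ca_out, dim=1, keepdim=True)          # (32, 1, 1, 1)
max_out, _ = torch.max(ca_out, dim=1, keepdim=True)        # (32, 1, 1, 1)
sa_in = torch.cat([avg_out, max_out], dim=1)               # (32, 2, 1, 1)
sa_out = nn.Conv2d(2, 1, 7, padding=3, bias=False)(sa_in)  # (32, 1, 1, 1)

# The next conv expects 64 input channels, so it fails with exactly the
# RuntimeError reported above.
conv = nn.Conv2d(64, 64, kernel_size=1)  # weight of size [64, 64, 1, 1]
try:
    conv(sa_out)
except RuntimeError as e:
    print(e)  # "... expected input[32, 1, 1, 1] to have 64 channels ..."
```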
The code changes are as follows.

Modify the CA and SA modules like this:
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),  # use the ratio argument instead of a hard-coded 16
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = self.sigmoid(avg_out + max_out)  # attention weights, (B, C, 1, 1)
        # old: return self.sigmoid(out)  -- returned only the weights
        # new: multiply inside the module and return (B, C, H, W)
        return out * x
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)  # (B, 2, H, W)
        x = self.sigmoid(self.conv1(x))           # attention weights, (B, 1, H, W)
        # new: multiply inside the module and return (B, C, H, W)
        return identity * x
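A quick standalone check (restating the two modified classes, with a dummy input size) confirms that each module now returns a tensor the same shape as its input, so the rest of the network sees ordinary (B, C, H, W) feature maps:

```python
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        out = self.sigmoid(self.fc(self.avg_pool(x)) + self.fc(self.max_pool(x)))
        return out * x  # multiply inside the module

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        att = self.sigmoid(self.conv1(torch.cat([avg_out, max_out], dim=1)))
        return x * att  # multiply inside the module

x = torch.randn(2, 64, 56, 56)  # dummy (B, C, H, W) feature map
out = SpatialAttention()(ChannelAttention(64)(x))
print(out.shape)  # torch.Size([2, 64, 56, 56]) -- shape preserved
```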
Then modify the ResNet class:
class ResNet(nn.Module):
    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # old: x = self.ca(x) * x
        x = self.ca(x)  # the modules now multiply internally
        x = self.sa(x)
        # old: x = self.sa(x) * x
        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)
        # old: x4 = self.ca1(x4) * x4
        x4 = self.ca1(x4)
        # old: last = self.sa1(x4) * x4
        last = self.sa1(x4)
        last = self.avgpool(last)
        last = torch.flatten(last, 1)
        last = self.fc(last)
        return last

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
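As a final sanity check, the broadcasting that makes the in-module multiplication work can be demonstrated with plain tensors (the sizes here are hypothetical, chosen to resemble ResNet's stem): multiplying (B, C, H, W) by (B, C, 1, 1) or by (B, 1, H, W) keeps the full feature-map shape, so the first conv of layer1 receives the 64 channels it expects:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 64, 56, 56)                   # stem output, (B, C, H, W)
ca_w = torch.sigmoid(torch.randn(2, 64, 1, 1))   # channel-attention weights
sa_w = torch.sigmoid(torch.randn(2, 1, 56, 56))  # spatial-attention weights

x = x * ca_w  # broadcasts over H, W -> still (2, 64, 56, 56)
x = x * sa_w  # broadcasts over C    -> still (2, 64, 56, 56)

# stand-in for the first conv of layer1, which expects 64 input channels
out = nn.Conv2d(64, 64, 3, padding=1, bias=False)(x)
print(out.shape)  # torch.Size([2, 64, 56, 56]) -- no channel mismatch
```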