Channel-dimension mismatch when adding CBAM

While adding CBAM by following this blog post (https://blog.csdn.net/qq_38410428/article/details/103694759), I ran into a dimension-mismatch error:

RuntimeError: Given groups=1, weight of size [64, 64, 1, 1], expected input[32, 1, 1, 1] to have 64 channels, but got 1 channels instead

Debugging showed that the problem arises before the Bottleneck blocks: after passing through the CA (channel attention) and SA (spatial attention) modules, the feature map collapses to (B, 1, 1, 1). After modifying and re-testing the code, the problem was solved.
The fix: the CA module produces an output of shape (B, C, 1, 1) and the SA module one of shape (B, 1, H, W). The original code multiplied these outputs with the feature map (B, C, H, W) inside the ResNet class's forward; I changed this so the multiplication happens directly inside the attention modules, and adjusted ResNet's forward accordingly.
My guess is that this problem is caused by the PyTorch version.
Adapt your own code along the same lines. Discussion is welcome.
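
To see concretely why the forward pass collapses: before the fix, each module returns only its attention map, so assigning that map back to x (without multiplying it onto the feature map) throws the feature map away. Here is a minimal shape sketch, using simplified pooling as a stand-in for the real CA/SA internals (the tensor sizes are illustrative):

import torch

x = torch.randn(32, 64, 56, 56)  # feature map after conv1/bn1/relu

# CA returns only its attention map: (32, 64, 1, 1)
x = torch.sigmoid(x.mean(dim=(2, 3), keepdim=True))
# SA then pools over channels, leaving (32, 1, 1, 1)
x = torch.sigmoid(x.mean(dim=1, keepdim=True))

print(x.shape)  # torch.Size([32, 1, 1, 1]) -- the input[32, 1, 1, 1] in the error above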

The code changes are as follows:

Modify the CA and SA modules as follows:

import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        # Shared MLP implemented as 1x1 convolutions (note: use `ratio` here,
        # not a hard-coded 16, so the constructor argument actually takes effect)
        self.fc = nn.Sequential(nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False),
                                nn.ReLU(),
                                nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = self.sigmoid(avg_out + max_out)  # attention map, shape (B, C, 1, 1)
        # Changed: multiply by the input here instead of in ResNet's forward
        return out * x  # shape (B, C, H, W)


class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()

        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        identity = x
        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)
        x = torch.cat([avg_out, max_out], dim=1)        # (B, 2, H, W)
        x = self.sigmoid(self.conv1(x))                 # attention map, shape (B, 1, H, W)
        # Changed: multiply by the input here instead of in ResNet's forward
        return identity * x  # shape (B, C, H, W)
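
A quick sanity check (a minimal sketch; the batch and spatial sizes are just illustrative) confirms that both modules now preserve the input shape, so the layers that follow them see the expected channel count:

ca = ChannelAttention(64)
sa = SpatialAttention()
x = torch.randn(32, 64, 56, 56)
assert ca(x).shape == x.shape  # (32, 64, 56, 56), no longer (32, 64, 1, 1)
assert sa(x).shape == x.shape  # (32, 64, 56, 56), no longer (32, 1, 56, 56)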
     

Modify the ResNet class as follows:

class ResNet(nn.Module):
    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        # Changed: the attention modules multiply internally now, so
        # `x = self.ca(x) * x` becomes `x = self.ca(x)` (same for sa)
        x = self.ca(x)
        x = self.sa(x)

        x = self.maxpool(x)

        x1 = self.layer1(x)
        x2 = self.layer2(x1)
        x3 = self.layer3(x2)
        x4 = self.layer4(x3)

        # Second attention pair after the last residual stage, changed the
        # same way (`self.ca1(x4) * x4` becomes `self.ca1(x4)`)
        x4 = self.ca1(x4)
        last = self.sa1(x4)

        last = self.avgpool(last)
        last = torch.flatten(last, 1)
        last = self.fc(last)

        return last

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
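
For completeness: the forward above references self.ca, self.sa, self.ca1 and self.sa1, which must be registered in ResNet.__init__. That part is covered in the linked blog; as a rough sketch, the additions would look something like the following (the placement and channel counts are assumptions based on a standard torchvision ResNet, where conv1 outputs 64 channels and layer4 outputs 512 * block.expansion):

# Sketch of the additions to ResNet.__init__ (assumed; see the linked blog):
self.ca = ChannelAttention(64)                      # conv1 outputs 64 channels
self.sa = SpatialAttention()
...
self.ca1 = ChannelAttention(512 * block.expansion)  # channel count after layer4
self.sa1 = SpatialAttention()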