Code
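A minimal PyTorch implementation of CBAM (Convolutional Block Attention Module, Woo et al., ECCV 2018). Given a feature map `F`, CBAM first computes a channel attention map `M_c(F) = σ(MLP(AvgPool(F)) + MLP(MaxPool(F)))` and multiplies it into `F`, then computes a spatial attention map `M_s(F') = σ(conv7×7([AvgPool(F'); MaxPool(F')]))` over the channel-refined features `F'` and multiplies again.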
```python
import torch
import torch.nn as nn


class Channel_Attention(nn.Module):
    """Channel attention: squeeze the spatial dims with max- and avg-pooling,
    pass both descriptors through a shared two-layer MLP, and sum."""

    def __init__(self, in_channels):
        super().__init__()
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Shared MLP as 1x1 convolutions with a reduction ratio of 16
        # (assumes in_channels >= 16).
        self.fc1 = nn.Conv2d(in_channels, in_channels // 16, 1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Conv2d(in_channels // 16, in_channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        max_pool = self.max_pool(x)  # (B, C, 1, 1)
        avg_pool = self.avg_pool(x)  # (B, C, 1, 1)
        fc_max = self.fc2(self.relu(self.fc1(max_pool)))
        fc_avg = self.fc2(self.relu(self.fc1(avg_pool)))
        # Per-channel weights in (0, 1).
        return self.sigmoid(fc_max + fc_avg)


class Spatial_Attention(nn.Module):
    """Spatial attention: pool along the channel dim, then convolve the two
    resulting maps down to a single-channel attention map."""

    def __init__(self, kernel_size=7):
        super().__init__()
        # padding = kernel_size // 2 keeps H and W unchanged for any odd
        # kernel size (3 for the default kernel_size=7).
        self.conv = nn.Conv2d(2, 1, kernel_size, stride=1,
                              padding=kernel_size // 2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)    # (B, 1, H, W)
        max_out, _ = torch.max(x, dim=1, keepdim=True)  # (B, 1, H, W)
        out = torch.cat([avg_out, max_out], dim=1)      # (B, 2, H, W)
        return self.sigmoid(self.conv(out))             # (B, 1, H, W)


class CBAM(nn.Module):
    """CBAM block: channel attention first, then spatial attention."""

    def __init__(self, in_channels):
        super().__init__()
        self.channel_atten = Channel_Attention(in_channels)
        self.spatial_atten = Spatial_Attention()

    def forward(self, x):
        x = self.channel_atten(x) * x  # reweight channels
        x = self.spatial_atten(x) * x  # reweight spatial locations
        return x
```
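As a quick sanity check, CBAM can wrap any 4-D feature map and returns a tensor of the same shape. A minimal sketch (batch size, channel count, and spatial size are arbitrary example values; `in_channels` must be at least 16 for the reduction in `Channel_Attention`):

```python
# Smoke test: CBAM should preserve the input shape.
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width)
cbam = CBAM(in_channels=64)
out = cbam(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```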
Figures (the overall CBAM block, and the channel / spatial attention sub-modules)
![cbam_overview](https://i-blog.csdnimg.cn/blog_migrate/1e97ed7b89719cc79ae0af2a3ff5e265.png)
![cbam_modules](https://i-blog.csdnimg.cn/blog_migrate/9b5a09ad7fedbccafaae2cba0ffe4d84.png)