My personal take: the SE mechanism performs very well when the channel count is large, but when used on top of a ResNet it is not very good at capturing boundary/edge features. That motivated the SEA attention mechanism: alongside SE's global average pooling it adds a global max-pooling branch (max pooling responds more strongly to salient, edge-like activations), then applies two sequential learned sigmoid gates to the input, followed by a residual connection.
Below is my implementation. The code is commented, so it should be easy to follow. If it helps you, please leave a like!
import torch
import torch.nn as nn

class SEA_Block(nn.Module):
    def __init__(self, channel, reduction=16):
        super(SEA_Block, self).__init__()
        # Two parallel global pooling branches: max pooling highlights salient
        # (edge-like) responses, average pooling summarizes global context.
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # Shared 1x1 conv squeezes the concatenated descriptor: 2C -> 2C // r.
        self.conv_reduce = nn.Conv2d(2 * channel, (2 * channel) // reduction, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        # Two expansion branches produce separate per-channel gates: 2C // r -> C.
        self.conv_expand_left = nn.Conv2d((2 * channel) // reduction, channel, kernel_size=1, bias=False)
        self.conv_expand_right = nn.Conv2d((2 * channel) // reduction, channel, kernel_size=1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_pool = self.avg_pool(x)                            # (B, C, 1, 1)
        max_pool = self.max_pool(x)                            # (B, C, 1, 1)
        concat_pool = torch.cat([avg_pool, max_pool], dim=1)   # (B, 2C, 1, 1)
        reduced = self.relu(self.conv_reduce(concat_pool))     # (B, 2C//r, 1, 1)
        left = self.sigmoid(self.conv_expand_left(reduced))    # (B, C, 1, 1)
        right = self.sigmoid(self.conv_expand_right(reduced))  # (B, C, 1, 1)
        left_weighted = x * left                 # apply the left-branch channel weights
        right_weighted = left_weighted * right   # then the right-branch weights
        out = right_weighted + x                 # residual connection back to the input
        return out
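
As a quick sanity check, the block preserves the input shape; the batch and spatial sizes below are just an arbitrary example:

import torch

block = SEA_Block(channel=64, reduction=16)
x = torch.randn(8, 64, 32, 32)   # (B, C, H, W), arbitrary example shape
out = block(x)
print(out.shape)                 # torch.Size([8, 64, 32, 32]), same as the input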
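
Since the motivation above mentions ResNet, here is a minimal sketch of how the block might be dropped into a residual block. This BasicBlockWithSEA is a simplified, hypothetical stand-in (stride 1, equal in/out channels), not torchvision's exact BasicBlock; note that SEA_Block already carries its own internal residual, so the skip connection here stacks on top of it:

import torch.nn as nn

class BasicBlockWithSEA(nn.Module):
    # Hypothetical simplified ResNet BasicBlock with SEA on the residual branch.
    def __init__(self, channels, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        self.sea = SEA_Block(channels, reduction)  # recalibrate channels before the skip add

    def forward(self, x):
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = self.sea(out)         # channel attention on the residual branch
        return self.relu(out + x)   # standard ResNet skip connection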