class ChannelAttention(nn.Module):
def __init__(self,channel,reduction=16):
super().__init__()
self.maxpool=nn.AdaptiveMaxPool2d(1)
self.avgpool=nn.AdaptiveAvgPool2d(1)
self.se=nn.Sequential(
nn.Conv2d(channel,channel//reduction,1,bias=False),
nn.ReLU(),
nn.Conv2d(channel//reduction,channel,1,bias=False)
)
self.sigmoid=nn.Sigmoid()
def forward(self, x) :
max_result=self.maxpool(x)
avg_result=self
ChannelAttention, SpatialAttention, and SEAttention代码
最新推荐文章于 2024-04-18 20:53:00 发布
本文介绍了在PyTorch中实现的三种注意力机制:Channel Attention、Spatial Attention和SEAttention,这些都是深度学习和神经网络中增强模型表现的重要技术。通过这些机制,模型能够更好地聚焦于输入数据的关键部分,提升特征学习和泛化能力。
摘要由CSDN通过智能技术生成