class Spartial_Attention(nn.Module):
    """Spatial attention gate (CBAM-style).

    Collapses the channel axis of the input with both a mean and a max,
    stacks the two single-channel maps, and feeds them through a
    ``kernel_size`` x ``kernel_size`` convolution followed by a sigmoid.
    The resulting (N, 1, H, W) mask is broadcast-multiplied onto the input.

    The class name keeps its original spelling ("Spartial") so existing
    imports and saved ``state_dict`` keys keep working.
    """

    def __init__(self, kernel_size):
        """Build the mask-producing conv + sigmoid stack.

        Args:
            kernel_size: Side length of the square convolution window. Must be
                a positive odd integer so that ``(kernel_size - 1) // 2``
                padding keeps the spatial size unchanged.

        Raises:
            ValueError: If ``kernel_size`` is not a positive odd integer.
        """
        super(Spartial_Attention, self).__init__()
        # Explicit check instead of ``assert``: asserts are removed under
        # ``python -O``. The ``>= 1`` test also rejects negative odd values
        # such as -1, which satisfy ``% 2 == 1`` in Python but would produce
        # a negative padding.
        if kernel_size < 1 or kernel_size % 2 != 1:
            raise ValueError(
                "kernel_size must be a positive odd integer, "
                "got kernel_size = {}".format(kernel_size)
            )
        # "Same" padding for an odd kernel at the default stride/dilation of 1.
        padding = (kernel_size - 1) // 2
        # The attribute name and the Sequential ordering form the state_dict
        # keys ``_Spartial_Attention__layer.0.*``; do not rename or reorder.
        self.__layer = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=kernel_size, padding=padding),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``x`` reweighted by a per-position attention mask.

        Args:
            x: Feature map of shape (N, C, H, W); dim 1 is treated as the
                channel axis.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Channel-wise descriptors, each (N, 1, H, W).
        avg_mask = torch.mean(x, dim=1, keepdim=True)
        max_mask, _ = torch.max(x, dim=1, keepdim=True)
        mask = torch.cat([avg_mask, max_mask], dim=1)  # (N, 2, H, W)
        mask = self.__layer(mask)  # (N, 1, H, W), values in (0, 1)
        return x * mask


class Channel_Attention(nn.Module):
    """Channel attention gate (CBAM-style).

    Squeezes every channel to a single value with both global average and
    global max pooling, runs both descriptors through one shared bottleneck
    MLP (implemented as 1x1 convolutions), sums the two results and applies a
    sigmoid. The resulting (N, C, 1, 1) weights are broadcast-multiplied onto
    the input.
    """

    def __init__(self, channel, r):
        """Build the pooling layers and the shared bottleneck MLP.

        Args:
            channel: Number of input channels C.
            r: Reduction ratio of the bottleneck; the hidden width is
                ``max(1, channel // r)``.

        Raises:
            ValueError: If ``channel`` or ``r`` is smaller than 1.
        """
        super(Channel_Attention, self).__init__()
        if channel < 1 or r < 1:
            raise ValueError(
                "channel and r must be >= 1, "
                "got channel = {}, r = {}".format(channel, r)
            )
        # Floor at 1: with channel < r, ``channel // r`` is 0 and the
        # bottleneck would have zero width, so it could not carry any
        # channel information.
        hidden = max(1, channel // r)
        self.__avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.__max_pool = nn.AdaptiveMaxPool2d((1, 1))
        # The attribute name and layer order form the state_dict keys
        # ``_Channel_Attention__fc.{0,2}.weight``; do not rename or reorder.
        self.__fc = nn.Sequential(
            nn.Conv2d(channel, hidden, 1, bias=False),
            nn.ReLU(True),
            nn.Conv2d(hidden, channel, 1, bias=False),
        )
        self.__sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` reweighted by per-channel attention weights.

        Args:
            x: Feature map of shape (N, C, H, W) with C == ``channel``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Both pooled descriptors are (N, C, 1, 1) and share the same MLP.
        y1 = self.__fc(self.__avg_pool(x))
        y2 = self.__fc(self.__max_pool(x))
        y = self.__sigmoid(y1 + y2)  # (N, C, 1, 1), values in (0, 1)
        return x * y
# Reference: https://blog.csdn.net/weixin_42907473/article/details/106525668