# SENet: Squeeze-and-Excitation channel-attention module
# -*- encoding: utf-8 -*-
"""
@File : seNet.py
@Time : 2021-12-29 17:44
@Author : XD
@Email : gudianpai@qq.com
@Software: PyCharm
"""
import torch
from torch import nn
from torchsummary import summary
class seNet(nn.Module):
    """Squeeze-and-Excitation (SE) channel-attention block.

    Average-pools each channel of the input down to a single value, feeds the
    resulting per-channel vector through a bottleneck MLP
    (Linear -> ReLU -> Linear -> Sigmoid) to obtain one gate in [0, 1] per
    channel, and multiplies the input channels by those gates.

    Args:
        channel: Number of input channels; the output has the same shape as
            the input.
        ratio: Reduction factor for the bottleneck width ``channel // ratio``.
            The width is clamped to at least 1 so that ``channel < ratio``
            does not produce a zero-width hidden layer.
    """

    def __init__(self, channel, ratio=16):
        super().__init__()
        # With channel < ratio, channel // ratio is 0: the hidden layer would
        # be empty, its output always 0, and every gate a constant
        # sigmoid(0) = 0.5 regardless of the input. Keep at least one unit.
        hidden = max(1, channel // ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, hidden, bias=False),
            nn.ReLU(),
            nn.Linear(hidden, channel, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned gate.

        Args:
            x: 4-D tensor unpacked as ``(b, c, h, w)``; ``c`` must equal the
                ``channel`` passed to the constructor.

        Returns:
            Tensor with the same shape as ``x``.
        """
        b, c, _, _ = x.size()
        # (b, c, h, w) -> (b, c, 1, 1) -> (b, c)
        squeezed = self.avg_pool(x).view(b, c)
        # (b, c) -> (b, hidden) -> (b, c) -> (b, c, 1, 1) so it broadcasts
        # over the spatial dimensions of x.
        gate = self.fc(squeezed).view(b, c, 1, 1)
        return x * gate
# Smoke test: build an SE block for 512 channels, move it to the available
# device, and print its structure and a torchsummary layer report.
model = seNet(channel=512)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)
print(model)
# torchsummary takes the device *type* ("cuda" or "cpu"). Passing the actual
# type keeps the summary's dummy inputs on the same device as the model, so
# this also runs on machines without a GPU.
summary(model, input_size=[(512, 26, 26)], batch_size=2, device=device.type)