out = out.view(b, -1)
out = self.SEblock(out)
out = out.view(b, c, 1, 1)
return out * x
class SEInvertedBottleneck(nn.Module):
    """Inverted-bottleneck block with an optional squeeze-and-excite stage.

    Stages, in order:
      1. ``conv``: 1x1 ``Conv1x1BNActivation`` (in_channels -> mid_channels).
      2. ``depth_conv``: ``ConvBNActivation`` with ``kernel_size``/``stride``
         (mid_channels -> mid_channels). Presumably depthwise given the
         name -- NOTE(review): confirm ``groups`` inside the helper.
      3. ``SEblock`` (only if ``use_se``): ``SqueezeAndExcite`` on mid_channels.
      4. ``point_conv``: 1x1 ``Conv1x1BNActivation`` (mid_channels -> out_channels).
    When ``stride == 1`` a 1x1 ``Conv1x1BN`` projection of the input is added
    to the output as a shortcut.

    Args:
        in_channels: Number of input channels.
        mid_channels: Width of the expanded hidden representation.
        out_channels: Number of output channels.
        kernel_size: Kernel size of the ``depth_conv`` stage.
        stride: Stride of the ``depth_conv`` stage; 1 enables the shortcut.
        activate: Activation selector forwarded unchanged to the
            ``*BNActivation`` helpers.
        use_se: Whether to insert the squeeze-and-excite stage.
        se_kernel_size: Forwarded to ``SqueezeAndExcite``.
    """

    def __init__(self, in_channels, mid_channels, out_channels, kernel_size,
                 stride, activate, use_se, se_kernel_size=1):
        super().__init__()
        self.stride = stride
        self.use_se = use_se
        # The original overwrote ``mid_channels`` with
        # ``in_channels * expansion_factor``. ``expansion_factor`` is never
        # defined (NameError), and the overwrite would ignore the width the
        # caller passes explicitly, so the argument is used as given.
        self.conv = Conv1x1BNActivation(in_channels, mid_channels, activate)
        self.depth_conv = ConvBNActivation(
            mid_channels, mid_channels, kernel_size, stride, activate
        )
        if self.use_se:
            self.SEblock = SqueezeAndExcite(mid_channels, mid_channels, se_kernel_size)
        # NOTE(review): the projection applies an activation here; canonical
        # MobileNetV3 uses a linear projection. Confirm this is intended.
        self.point_conv = Conv1x1BNActivation(mid_channels, out_channels, activate)
        if self.stride == 1:
            # A 1x1 projection shortcut lets the residual be added even when
            # in_channels != out_channels (spatial size is kept at stride 1).
            self.shortcut = Conv1x1BN(in_channels, out_channels)

    def forward(self, x):
        """Run the block on ``x`` and return the (possibly residual) output."""
        out = self.depth_conv(self.conv(x))
        if self.use_se:
            out = self.SEblock(out)
        out = self.point_conv(out)
        if self.stride == 1:
            out = out + self.shortcut(x)
        return out
class MobileNetV3(nn.Module):
def init(self, num_classes=1000,type=‘large’):
super(MobileNetV3, self).init()
self.type = type
self.first_conv = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
nn.BatchNorm2d(16),
HardSwish(inplace=True),
)
if type==‘large’:
self.large_bottleneck = nn.Sequential(
SEInvertedBottleneck(in_channels=16, mid_channels=16, out_channels=16, kernel_size=3, stride=1,activate=‘relu