GhostNet pytorch版本 源码解析 (二)Ghost Bottlenecks类
Ghost Bottlenecks类
论文
stride=1:
Ghost Bottleneck由两个堆叠的Ghost modules构成
第一个Ghost modules:
负责增加通道数,是一个膨胀层
膨胀率:输出通道数与输入通道数的比值
第二个Ghost modules:
负责减小通道数,使得通道数与shortcut匹配
stride=2:
shortcut部分包括一个下采样层和一个stride=2的depthwise卷积层
代码
class SELayer(nn.Module):
    """Squeeze-and-Excite block that recalibrates channels with a learned gate.

    Pipeline: global average pool -> FC (reduce by `reduction`) -> ReLU ->
    FC (restore) -> clamp to [0, 1]. The clamp plays the role of a hard
    sigmoid, as in the official GhostNet implementation.
    """

    def __init__(self, channel, reduction=4):
        super(SELayer, self).__init__()
        # Collapse each feature map to a single value: (B, C, H, W) -> (B, C, 1, 1);
        # the channel dimension is untouched.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(channel, channel // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channel // reduction, channel), )

    def forward(self, x):
        batch, channels = x.shape[0], x.shape[1]
        # Squeeze: pooled maps flattened into a (batch, channels) matrix.
        squeezed = self.avg_pool(x).view(batch, channels)
        # Excite: per-channel scores reshaped to (batch, channels, 1, 1)
        # so they broadcast over the spatial dimensions.
        gate = self.fc(squeezed).view(batch, channels, 1, 1)
        # Clamp the gate into [0, 1] (hard-sigmoid-style) before scaling.
        gate = torch.clamp(gate, 0, 1)
        return x * gate
def depthwise_conv(inp, oup, kernel_size=3, stride=1, relu=False):
    """Build a depthwise conv -> BatchNorm -> (optional ReLU) stack.

    `groups=inp` makes the convolution depthwise (each input channel is
    filtered independently), and `kernel_size // 2` padding keeps the
    spatial size unchanged when stride is 1. When `relu` is False an empty
    `nn.Sequential()` placeholder is kept in slot 2 so the module layout
    (and therefore the state_dict keys) is identical either way.
    """
    activation = nn.ReLU(inplace=True) if relu else nn.Sequential()
    layers = [
        nn.Conv2d(inp, oup, kernel_size, stride,
                  kernel_size // 2, groups=inp, bias=False),
        nn.BatchNorm2d(oup),
        activation,
    ]
    return nn.Sequential(*layers)
class GhostBottleneck(nn.Module):
    """Ghost bottleneck: expand -> (optional DW conv) -> (optional SE) -> project.

    Main path: a GhostModule widens `inp` to `hidden_dim` (the expansion
    layer), a stride-2 depthwise conv is inserted only when downsampling,
    an SE block is inserted only when `use_se` is set, and a second, linear
    GhostModule projects back down to `oup`. A shortcut branch is added to
    the main path's output (residual connection).
    """

    def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se):
        super(GhostBottleneck, self).__init__()
        assert stride in [1, 2]  # only stride 1 or 2 is supported

        self.conv = nn.Sequential(
            # Pointwise expansion: first GhostModule raises the channel count.
            GhostModule(inp, hidden_dim, kernel_size=1, relu=True),
            # Depthwise conv only when downsampling (stride == 2);
            # otherwise an empty placeholder keeps the slot.
            depthwise_conv(hidden_dim, hidden_dim, kernel_size, stride,
                           relu=False) if stride == 2 else nn.Sequential(),
            # Squeeze-and-Excite on the expanded features, if requested.
            SELayer(hidden_dim) if use_se else nn.Sequential(),
            # Pointwise-linear projection back to `oup`. relu=False makes the
            # GhostModule's cheap operation linear, mirroring MobileNetV2's
            # linear bottleneck.
            GhostModule(hidden_dim, oup, kernel_size=1, relu=False),
        )

        if stride == 1 and inp == oup:
            # Shapes already match: identity shortcut.
            self.shortcut = nn.Sequential()
        else:
            # Shape mismatch (stride 2 and/or inp != oup): depthwise conv
            # handles the stride, then a 1x1 conv + BN aligns the channels.
            self.shortcut = nn.Sequential(
                depthwise_conv(inp, inp, 3, stride, relu=True),
                nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
                nn.BatchNorm2d(oup),
            )

    def forward(self, x):
        # Residual sum of the main path and the shortcut branch.
        return self.conv(x) + self.shortcut(x)
如有错误希望大家批评指正!感谢!