1. SqueezeNet
Iandola, Forrest N., et al. “SqueezeNet: AlexNet-level accuracy with 50x fewer parameters and <0.5 MB model size.” arXiv preprint arXiv:1602.07360 (2016).
本文提出了一种轻量化的图像分类网络:SqueezeNet,正如文章的标题所写,参数量比AlexNet少50倍。可以被部署于嵌入式设备,在分布式训练时减少通信开销。
文章的思想来源于三个设计网络结构的策略:
- Replace 3x3 filters with 1x1 filters (尽量用1x1卷积替换3x3卷积)
- Decrease the number of input channels to 3x3 filters (3x3卷积的输入通道数尽可能减少)
- Downsample late in the network so that convolution layers have large activation maps (下采样放到网络后面做,这样可以使得网络中间特征图的尺寸足够大,这有利于提高模型精度)
其中策略1,2可以降低参数量,策略3是用于保证模型的精确度。
基于这三个策略,作者设计了如上图所示的网络模块,称为Fire Module。输入先经过squeeze层的1x1卷积进行降维处理,这里降的是通道维,因为后面需要用3x3卷积进行处理,遵循策略2需要减少3x3卷积输入的通道数。expand层由1x1卷积和3x3卷积混合组成,然后将两种卷积的输出在通道维进行拼接。整个过程都没有下采样,因此输出的图像尺寸和输入一样,但是输出的通道数就变成了expand层1x1卷积的输出通道数加上3x3卷积的输出通道数。
以下给出Fire Module的代码:
class Fire(nn.Module):
    """SqueezeNet Fire module: a 1x1 squeeze conv followed by parallel
    1x1 and 3x3 expand convs whose outputs are concatenated on channels.

    Output shape: [batch_size, expand1x1_planes + expand3x3_planes, H, W]
    — spatial size is preserved (no downsampling inside the module).
    """

    def __init__(self, inplanes, squeeze_planes, expand1x1_planes, expand3x3_planes):
        super(Fire, self).__init__()
        # 1x1 squeeze conv reduces the channel count fed to the 3x3 expand
        # conv (paper strategy 2: fewer input channels to 3x3 filters).
        self.squeeze = nn.Conv2d(in_channels=inplanes,
                                 out_channels=squeeze_planes,
                                 kernel_size=1)
        self.squeeze_activation = nn.ReLU(inplace=True)
        # Mixed expand convolutions; neither branch changes H or W.
        self.expand1x1 = nn.Conv2d(in_channels=squeeze_planes,
                                   out_channels=expand1x1_planes,
                                   kernel_size=1)
        self.expand1x1_activation = nn.ReLU(inplace=True)
        self.expand3x3 = nn.Conv2d(in_channels=squeeze_planes,
                                   out_channels=expand3x3_planes,
                                   kernel_size=3,
                                   padding=1)  # padding=1 keeps spatial size
        self.expand3x3_activation = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.squeeze(x)
        x = self.squeeze_activation(x)
        y1 = self.expand1x1_activation(self.expand1x1(x))
        y2 = self.expand3x3_activation(self.expand3x3(x))
        # Concatenate the two expand branches along the channel dimension.
        return torch.cat((y1, y2), dim=1)
将若干Fire Module进行叠加,并在其间用最大池化层进行下采样:
class SqueezeNet(nn.Module):
    """SqueezeNet classifier built from stacked Fire modules.

    Args:
        version: "1.0" (7x7 stem, 96 channels) or "1.1" (3x3 stem, 64
            channels, earlier downsampling). Any other value raises
            ValueError.
        num_classes: number of output classes.

    Note: the stem conv takes 1-channel input (e.g. grayscale images).
    """

    def __init__(self, version="1.0", num_classes=10):
        super(SqueezeNet, self).__init__()
        self.num_classes = num_classes
        if version == "1.0":
            self.features = nn.Sequential(
                nn.Conv2d(in_channels=1, out_channels=96, kernel_size=7, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(96, 16, 64, 64),
                Fire(128, 16, 64, 64),
                Fire(128, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 32, 128, 128),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(512, 64, 256, 256),
            )
        elif version == "1.1":
            self.features = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=3, stride=2),
                nn.ReLU(inplace=True),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(64, 16, 64, 64),
                Fire(128, 16, 64, 64),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(128, 32, 128, 128),
                Fire(256, 32, 128, 128),
                nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True),
                Fire(256, 48, 192, 192),
                Fire(384, 48, 192, 192),
                Fire(384, 64, 256, 256),
                Fire(512, 64, 256, 256),
            )
        else:
            # Fail fast: the original code left self.features undefined for
            # unknown versions, surfacing later as a confusing AttributeError
            # inside forward().
            raise ValueError(
                "Unsupported SqueezeNet version {!r}; expected '1.0' or '1.1'".format(version)
            )
        # Fully-convolutional head: 1x1 conv produces per-class maps that are
        # reduced to a single score per class by global average pooling.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(512, self.num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1))
        )
        # Kaiming-uniform init for all conv weights; zero the biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        x = self.features(x)
        x = self.classifier(x)
        # Flatten [B, num_classes, 1, 1] to [B, num_classes].
        return torch.flatten(x, 1)
最后再附上带有残差连接结构的:
class Squeeze(nn.Module):
    """SqueezeNet variant with identity skip connections.

    Residual additions wrap fire3, fire5, fire7 and fire9 — exactly the
    Fire modules whose input and output channel counts agree (128, 256,
    384 and 512 respectively), so the element-wise add is shape-safe.
    The stem conv expects 1-channel (grayscale) input.
    """

    def __init__(self, num_classes=10):
        super(Squeeze, self).__init__()
        self.conv1 = nn.Conv2d(1, 96, kernel_size=3, stride=2)
        self.relu = nn.ReLU(inplace=True)
        self.max1 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.fire2 = Fire(96, 16, 64, 64)
        self.fire3 = Fire(128, 16, 64, 64)
        self.fire4 = Fire(128, 16, 128, 128)
        self.max2 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.fire5 = Fire(256, 32, 128, 128)
        self.fire6 = Fire(256, 32, 192, 192)
        self.fire7 = Fire(384, 64, 192, 192)
        self.fire8 = Fire(384, 64, 256, 256)
        self.max3 = nn.MaxPool2d(kernel_size=3, stride=2, ceil_mode=True)
        self.fire9 = Fire(512, 64, 256, 256)
        # 1x1 conv head + global average pooling yields one score per class.
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Conv2d(512, num_classes, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        # Kaiming-uniform init on every conv weight; zero biases where present.
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            nn.init.kaiming_uniform_(module.weight.data)
            if module.bias is not None:
                module.bias.data.zero_()

    def forward(self, x):
        x = self.max1(self.relu(self.conv1(x)))
        x = self.fire2(x)
        x = x + self.fire3(x)   # skip connection, 128 -> 128 channels
        x = self.max2(self.fire4(x))
        x = x + self.fire5(x)   # skip connection, 256 -> 256 channels
        x = self.fire6(x)
        x = x + self.fire7(x)   # skip connection, 384 -> 384 channels
        x = self.max3(self.fire8(x))
        x = x + self.fire9(x)   # skip connection, 512 -> 512 channels
        return torch.flatten(self.classifier(x), 1)