LeNex、AlexNet和VGG的设计共同点是:先有卷积层构成的模块充分抽取空间特征,再以由全连接层构成的模块来输出分类结果。
NiN提出了另外一个思路,即串联多个由卷积层和“全连接层(1x1卷积层)”构成的小网络老构建一个深层网络。
NiN块
卷积层的输入和输出通常是四维数组(样本,通道,高,宽),而全连接层的输入和输出则通常是二维数组(样本,特征)。如果想在全连接层后再接上卷积层,则需要将全连接层的输出变换为四维。1x1卷积层可以看成全连接层,其中空间维度(高,宽)上的每个元素相当于样本,通道相当于特征。NiN使用1x1卷积层来替代全连接层,从而使空间信息能够自然传递到后面的层中去。
NiN块是NiN中的基础块。由一个卷积层加两个充当全连接层的1x1卷积层串联而成。其中第一个卷积层的超参数可以自行设置,而第二个和第三个卷积层的超参数一般是固定的。
NiN模型
NiN与AlexNet有类似之处。NiN使用卷积窗口形状分别为11x11,5x5和3x3的卷积层,相应的输出通道数也与AlexNet中的一致。每个NiN块后接一个步幅为2,窗口形状为3x3的最大池化层。
NiN与AlexNet的不同点是:NiN去掉看AlexNet最后的3个全连接层,使用了输出通道数等于标签类别墅的NiN块,然后使用全局平均池化层对通道中所有元素求平均并直接用于分类。这里的全局平均池化层即窗口形状等于输入空间维形状的平均池化层。
NiN模块这样设计的好处是可以显著减少模型参数尺寸,从而缓解过拟合。然而,该设计有时会造成获得有效模型的训练时间的增加。
def nin_block(in_channel, out_channel, kernel_size, strides, padding):
layer = []
layer += [nn.Conv2d(in_channel, out_channel, kernel_size, strides,padding), nn.ReLU(inplace=True)]
layer += [nn.Conv2d(out_channel, out_channel, kernel_size=1),nn.ReLU(inplace=True)]
layer += [nn.Conv2d(out_channel, out_channel, kernel_size=1), nn.ReLU(inplace=True)]
return nn.Sequential(*layer)
layers = []
layers += [nin_block(1,96,kernel_size=11,strides=4,padding=0),nn.MaxPool2d(3,2)]
layers += [nin_block(96,256,kernel_size=5,strides=1,padding=2),nn.MaxPool2d(3,2)]
layers += [nin_block(256,384,kernel_size=3,strides=1,padding=1),nn.MaxPool2d(3,2),nn.Dropout(0.5)]
layers += [nin_block(384,10,kernel_size=3,strides=1,padding=1),nn.AdaptiveAvgPool2d()]
net_NiN = nn.Sequential(*layers)
x = torch.randn((1,1,224,224))
for layer in net_NiN:
x = layer(x)
print('output shape:\t', x.shape)
x= torch.flatten(x,start_dim=1)
print(x,x.shape)
output shape: torch.Size([1, 96, 54, 54])
output shape: torch.Size([1, 96, 26, 26])
output shape: torch.Size([1, 256, 26, 26])
output shape: torch.Size([1, 256, 12, 12])
output shape: torch.Size([1, 384, 12, 12])
output shape: torch.Size([1, 384, 5, 5])
output shape: torch.Size([1, 384, 5, 5])
output shape: torch.Size([1, 10, 5, 5])
output shape: torch.Size([1, 10, 1, 1])
tensor([[0.0592, 0.0000, 0.2903, 0.4870, 0.0000, 0.2750, 0.3790, 0.3298, 0.0000,
0.1780]], grad_fn=<AsStridedBackward>) torch.Size([1, 10])