GoogLeNet 是非常具有代表性的卷积神经网络之一,此网络由谷歌公司的 Christian Szegedy 等人设计提出1,并在 2014 年 ImageNet 挑战赛(图片分类)上夺取第一名桂冠2。GoogLeNet 不同于之前的经典网络,如 AlexNet,LeNet 或 VGG,它的设计更加颠覆传统。比如,它引入了 1x1 卷积核,使得整个网络虽然更深(22 层 Inception 网络层),但只需要训练很少的参数(500万),训练参数个数只有 AlexNet 的十分之一。
GoogLeNet 中的字母 L 之所以大写,据说是为了向前面的经典网络 LeNet 网络致敬。
1. 1x1 卷积核的意义
前篇文章介绍,卷积神经网络的卷积核大小通常是 奇数x奇数 的模式,并且以 3x3, 5x5, 7x7 为常见大小,这些卷积核能够识别出固定像素区域内的特征。试想一下,当我们使用 1x1 大小的卷积核时,其实就意味着后续的输出尺寸与原输出尺寸相同,只不过在深度上有所改变(与 filter 个数相关)。下面的图展示了利用 2 个 1x1x32 卷积核将原先 6x6x32 的结构变成了 6x6x2 的结构,深度减少。
因此,人们一般在网络中设置 1x1 的 filter 用来改变前一层神经网络的深度(即 Input 的 32)。在 Inception 网络层中,人们使用 1x1 的网络主要是为了减少所需的运算次数(当然,也有观点认为使用 1x1 卷积核使得感受野能堆叠更多的卷积,从而学习到更多的特征)。
2. 朴素版 Inception 网络层
Inception 的朴素版网络结构如下图所示,左侧部分描述了 Inception 的具体连接方式,右侧图是 Inception 的概要结构。假设输入是一个 28x28x192 的立方体,首先 Inception 会并行的进行点乘操作,对象分别是 3 种不同尺寸的filter,依次是 1x1x32, 3x3x32 和 5x5x32,他们的个数依次是 64,128 和 32。其中,3x3 和 5x5 的 filter 均使用了 Same Convolution 的填充方式保证其输出结果也是 28x28 的大小。此外,Inception 还增加了一个 Max pooling 的池化操作(当然,也使用了 Same Convolution 式的填充),该池化得到的结果是 28x28x192 的立方体。
最终,将 4 种类型的立方体 “连接” 起来(torch.cat),构成了 Inception 网络的最终输出,其大小为 28x28x416。
3. 改进版 Inception 网络层
朴素版的 Inception 网络层的一个很大的缺陷在于 5x5 卷积层上的计算量过大。如下左图所示,一个 28x28x192 的输入层,点乘 32 个 5x5x192 尺寸的 filter ,最终得到 28x28x32 的输出立方体,其所需要的乘法计算量为,
( 28 × 28 × 32 ) × ( 5 × 5 × 192 ) ≈ 1.2 亿 (28\times28\times32)\times(5\times5\times192) \approx 1.2 亿 (28×28×32)×(5×5×192)≈1.2亿
研究人员于是通过添加 1x1 的 filter 来巧妙地在不降低质量的情况下,降低计算量(如下右图所示),。首先输入连接 16 个大小为 1x1x192 的卷积核降低信道个数;其次再连接 32 个 5x5x16 的卷积核得到最终输出,其所需要的计算量为,
( 28 × 28 × 32 ) × ( 5 × 5 × 16 ) + ( 28 × 28 × 16 ) × ( 1 × 1 × 192 ) ≈ 1240 万 (28\times28\times32)\times(5\times5\times16) + (28\times28\times16)\times(1\times1\times192) \approx 1240 万 (28×28×32)×(5×5×16)+(28×28×16)×(1×1×192)≈1240万
由这个例子可知,可以通过 1x1 的卷积核降低信道深度,后面再接 5x5 或 3x3 的卷积核,这样会极大的减少网络的计算量。这就是改进版 Inception 的灵感来源。
再版 Inception 网络层的整体结果如下,他在 3x3,5x5 的卷积核之前还连接了 1x1 的卷积核。此外,在 Max-pooling 池化层的后面添加了 1x1 的卷积核,最终输出结果是 28x28x256 的输出立方体。
4. GoogLeNet 网络基本结构
在了解完 Inception 网络层(或 Inception 模块)后,理解 GoogLeNet 会更加的容易。总的来说,GoogLeNet 其实就是一层层 Inception 网络层的循环连接,一环套一环。
输入部分: 在 GoogLeNet 的最左侧是其输入部分,通过放大可以发现,输入(input)之后连接有 2 个卷积层,第一层使用 7x7 的卷积核以及 3x3 的最大池化层,第二层使用 1x1,3x3 的卷积层。图中绿色的 LocalRespNorm(LRN)指的是一种归一化的操作3,现在用的比较少了。
输出层: 输出层其实就是之前的 Inception 的堆叠,值得注意的是此处 GoogLeNet 的输出有 3 处,每次都会套用一个 Softmax 回归进行输出或辅助输出。
GoogLeNet 的细节细节请参见 45 等博客 ,总的来讲,GoogLeNet 的灵感就在于 Inception 模块的设计。后续的 GoogLeNet 也有很多改进的版本,如 v2,v3,v4,涉及到了更深层次的技术,如卷及分解,残差网络等等。个人觉得 1x1 矩阵的运用在 GoogLeNet (起码在 v1 版本)中起到了很重要的作用,这也使得神经网络的设计不再一味的追求深度,而是同时注重改进。
附录
下面的代码展现了 GoogLeNet 中重要的 Inception 的实现,首先定义个基本的卷积类 BasicConv2d
,
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchsummary
class BasicConv2d(nn.Module):
"""实现子卷积网络"""
def __init__(self, dim_in, dim_out, s, p):
super(BasicConv2d, self).__init__()
self.conv = nn.Conv2d(in_channels=dim_in, out_channels=dim_out, kernel_size=(s, s), padding=p)
self.bn = nn.BatchNorm2d(num_features=dim_out)
def forward(self, x):
x = self.conv(x)
out = self.bn(x)
return out
接着根据 BasicConv2d 来定义 Inception
结构,
class Inception(nn.Module):
"""朴素版 Inception 网络"""
def __init__(self, dim_in):
super(Inception, self).__init__()
self.branch1x1 = BasicConv2d(dim_in, 64, 1, 0)
self.branch3x3 = BasicConv2d(dim_in, 128, 3, 1)
self.branch5x5 = BasicConv2d(dim_in, 32, 5, 2)
self.branchpool = nn.MaxPool2d(kernel_size=(3, 3), stride=1, padding=1)
def forward(self, x):
x1 = self.branch1x1(x)
x3 = self.branch3x3(x)
x5 = self.branch5x5(x)
branch_pool = self.branchpool(x)
return torch.cat([x1, x3, x5, branch_pool], dim=1) # 拼接
改进版本的 Inception 被定义为 InceptionPro
,其实现如下,
class InceptionPro(nn.Module):
"""改进版 Inception 网络"""
def __init__(self, dim_in):
super(InceptionPro, self).__init__()
self.branch1x1 = BasicConv2d(dim_in, 64, 1, 0)
self.branch3x3_1 = BasicConv2d(dim_in, 96, 1, 0)
self.branch3x3_2 = BasicConv2d(96, 128, 3, 1)
self.branch5x5_1 = BasicConv2d(dim_in, 16, 1, 0)
self.branch5x5_2 = BasicConv2d(16, 32, 5, 2)
self.branchpool = nn.Sequential(
nn.MaxPool2d(kernel_size=(3, 3), stride=1, padding=1),
BasicConv2d(dim_in, 32, 1, 0) # 池化后还接了一个 1x1x192 的 filter
)
def forward(self, x):
x1 = self.branch1x1(x)
x3 = self.branch3x3_1(x)
x3 = self.branch3x3_2(x3)
x5 = self.branch5x5_1(x)
x5 = self.branch5x5_2(x5)
branch_pool = self.branchpool(x)
return torch.cat([x1, x3, x5, branch_pool], dim=1) # 拼接