复杂度对模型的影响
- 时间复杂度决定了模型的训练/预测时间。如果复杂度过高,会导致模型训练和预测耗费大量时间,既无法快速的验证想法和改善模型,也无法做到快速的预测。
- 空间复杂度决定了模型的参数数量。由于维度灾难(curse of dimensionality)的限制,模型的参数越多,训练模型所需的数据量就越大,而现实生活中的数据集通常不会太大,这会导致模型的训练更容易过拟合。
时间复杂度
时间复杂度即模型的运算次数。
单个卷积层的时间复杂度:Time~O(M^2 * K^2 * Cin * Cout)
注1:为了简化表达式中的变量个数,这里统一假设输入和卷积核的形状都是正方形。
注2:严格来讲每层应该还包含1个Bias参数,这里为了简洁就省略了。
- M:输出特征图(Feature Map)的尺寸。
- K:卷积核(Kernel)的尺寸。
- Cin:输入通道数。
- Cout:输出通道数。
- 输出特征图尺寸又由输入尺寸X、卷积核尺寸K、Padding、 Stride 这四个参数所决定,表示如下:
M=(X - K + 2*Padding) / Stride + 1
空间复杂度
空间复杂度即模型的参数数量,体现为模型本身的体积。
Space~O(K^2 * Cin * Cout)
- 空间复杂度只与卷积核的尺寸K、通道数C相关。而与输入图片尺寸无关。
- 当我们需要裁剪模型时,由于卷积核的尺寸通常已经很小,而网络的深度又与模型的能力紧密相关,不宜过多削减,因此模型裁剪通常最先下手的地方就是通道数。
复杂度优化(以Inception系列为例)
InceptionV1(1*1卷积对通道数降维)
- InceptionV1借鉴了Network in Network的思想,在一个Inception Module中构造了四个并行的不同尺寸的卷积/池化模块(上图左),有效的提升了网络的宽度。但是这么做也造成了网络的时间和空间复杂度的激增。对策就是添加 1*1 卷积(上图右红色模块)将输入通道数先降到一个较低的值,再进行真正的卷积。
- 使用 1*1 卷积降维可以降低时间复杂度3倍以上。
- 另外在空间复杂度上,虽然降维引入了三组 1*1 卷积核的参数,但新增参数量仅占整体的 5%,影响并不大。
InceptionV1(使用GAP代替 Flatten)
- 全连接层可以视为一种特殊的卷积层,其卷积核尺寸K与输入图尺寸X一模一样。每个卷积核的输出特征图是一个标量点,即M=1。复杂度分析如下:
Time~O(1^2 * X^2 * Cin * Cout)
Space~O(X^2 * Cin * Cout) - 可见,与真正的卷积层不同,全连接层的空间复杂度与输入数据的尺寸密切相关。因此如果输入图像尺寸越大,模型的体积也就会越大,这显然是不可接受的。例如早期的VGG系列模型,其 90% 的参数都耗费在全连接层上。
- InceptionV1中使用的全局最大池化GAP改善了这个问题。由于每个卷积核输出的特征图在经过全局最大池化后都会直接精炼成一个标量点,因此全连接层的复杂度不再与输入图像尺寸有关,运算量和参数数量都得以大规模削减。复杂度分析如下:
Time~O(Cin * Cout)
Space~O(Cin * Cout)
InceptionV2(使用两个3*3卷积核级联代替5*5卷积分支)
- 根据上面提到的二维卷积输入输出尺寸关系公式,可知:对于同一个输入尺寸,单个 5*5 卷积的输出与两个 3*3 卷积级联输出的尺寸完全一样,即感受野相同。
- 同样根据上面提到的复杂度分析公式,可知:这种替换能够非常有效的降低时间和空间复杂度。我们可以把辛辛苦苦省出来的这些复杂度用来提升模型的深度和宽度,使得我们的模型能够在复杂度不变的前提下,具有更大的容量。
InceptionV3(使用N*1与1*N代替N*N卷积)
- InceptionV3 中提出了卷积的因式分解(Factorization),在确保感受野不变的前提下进一步简化。
Xception()
- 我们之前讨论的都是标准卷积运算,每个卷积核都对输入的所有通道进行卷积。
- Xception 模型挑战了这个思维定势,它让每个卷积核只负责输入的某一个通道,这就是所谓的 Depth-wise Separable Convolution。
- 从输入通道的视角看,标准卷积中每个输入通道都会被所有卷积核蹂躏一遍,而 Xception 中每个输入通道只会被对应的一个卷积核扫描,降低了模型的冗余度。
- 标准卷积与可分离卷积的时间复杂度对比:可以看到本质上是把连乘转化成为相加。
总结
通过上面的推导和经典模型的案例分析,我们可以清楚的看到其实很多创新点都是围绕模型复杂度的优化展开的,其基本逻辑就是乘变加。模型的优化换来了更少的运算次数和更少的参数数量,一方面促使我们能够构建更轻更快的模型(例如MobileNet),一方面促使我们能够构建更深更宽的网络(例如Xception),提升模型的容量。