我的模型有多快?——深度学习网络模型的运算复杂度、空间占用和内存访问情况计算

前段时间看了几个笔试题,涉及模型复杂度,主要是参数量和计算复杂度的问题。当时搜了一下感觉中文网上的内容比较乱。刚好本文是对神经网络模型资源消耗情况的一篇介绍,就不自己写了,把关键内容搬运一下。

原文见 http://machinethink.net/blog/how-fast-is-my-model/

深度网络的计算消耗是学术 paper 相对少见的话题。当然,早期网络精度不够的情况下讨论压缩也没有意义。工程师需要实现模型并让网络尽可能地在各类环境下工作,模型的资源消耗情况和运行速度非常关键。

原文以移动端的模型应用为例,列出了四个主要问题:

  • 空间占用——单个模型的参数文件要占用多大空间
  • 内存占用——运行在手机或平板上时需要占用多大的 RAM
  • 运行速度——尤其考虑实时的视频和大图像处理情形
  • 耗电情况——我可不想要暖手宝

原文内容还是挺丰富的。为了节约时间,博主只把关键的部分略作摘录和演绎,即网络中一般操作的运算量、内存占用的计算方法,其他的如 MobileNet 模型的优化就不做介绍了。如有需要请查阅原文。

原文以外的内容将斜体显示,以作区别。

正文

案例:作者的一位客户最近用 MobileNetV2 替换掉了 V1 模型,按理说V2 的计算量远小于 V1 ,

(网上随手找的图:https://blog.csdn.net/u011995719/article/details/79135818
这里写图片描述

然鹅实际情况却是 V2 比 V1 慢得多。

(注:可参考
https://www.zhihu.com/question/265709710/answer/299136290https://www.reddit.com/r/MachineLearning/comments/8a7sf6/d_mobilenet_v2_paper_said_depthwise_separable/
官方已经放出模型 https://github.com/tensorflow/models/tree/master/research/slim/nets/mobilenet 页面上也有实验测试结果。看完全文也会发现 V2 不比 V1 慢。作者这里有点标题党。)

为什么会这样呢?

1.计算消耗

可以用 FLOPS(floating point operations per second,每秒浮点运算数)来衡量模型的速度。另一种方法是 MACCs(multiply-accumulate operations,乘-加操作),也叫 MAdds。但说穿了,都是点积运算而已。

什么叫乘-加?神经网络里的运算大都是这样的:

y = w[0]*x[0] + w[1]*x[1] + w[2]*x[2] + ... + w[n-1]*x[n-1]

wx 都是向量,y 是标量。上式是全连接层或卷积层的典型运算。一次乘-加运算即一次乘法+一次加法运算,所以上式的 MACCs 是n

不过可以看到,加法运算的次数并非 n 而是 n-1 。但考虑 MACCs 时可以类比算法复杂度估算的 big-O ,即结果可以是近似的。

而换到 FLOPS 的情况,点积做了 2n-1 FLOPS,即 n-1 次加法和 n 次乘法。可以看到,MACCs 大约是 FLOPS 的一半。

1.1 全连接层

全连接层的计算

y = matmul(x, W) + b

权重 W W <script id="MathJax-Element-3" type="math/tex">W</script> 是一个 I×J I × J <script id="MathJax-Element-4" type="math/tex">I \times J</script> 矩阵,输入 x x <script id="MathJax-Element-5" type="math/tex">x</script> 是 I I <script id="MathJax-Element-6" type="math/tex">I</script> 维实值向量, b b <script id="MathJax-Element-7" type="math/tex">b</script> 是 J J <script id="MathJax-Element-8" type="math/tex">J</script> 维偏置。输出 y y <script id="MathJax-Element-9" type="math/tex">y</script> 也是 J J <script id="MathJax-Element-10" type="math/tex">J</script> 维实值向量。FC 层的 MACCs 也不难计算。

上文例子是向量与向量的点积,FC 是向量与矩阵的点积,每一组点积发生在输入 x x <script id="MathJax-Element-11" type="math/tex">x</script> 同权重 W W <script id="MathJax-Element-12" type="math/tex">W</script> 某一列之间,计有 I I <script id="MathJax-Element-13" type="math/tex">I</script> MACCs,一共要计算 J J <script id="MathJax-Element-14" type="math/tex">J</script> 组点积,所以 FC 层的 MACCs 总计 I×J I × J <script id="MathJax-Element-15" type="math/tex">I \times J</script>,跟权重的尺寸一致。

偏置项 b b <script id="MathJax-Element-16" type="math/tex">b</script> 对 MACCs 的影响可以忽略不计。而上面也提到 MACCs 中加法比乘法少一次, b b <script id="MathJax-Element-17" type="math/tex">b</script> 刚好补上了这个缺。

所以,对 I I <script id="MathJax-Element-18" type="math/tex">I</script> 的输入、权重为 I×J I × J <script id="MathJax-Element-19" type="math/tex">I \times J</script> 的权重矩阵和 J J <script id="MathJax-Element-20" type="math/tex">J</script> 的输出,MACCs 为 I×J I × J <script id="MathJax-Element-21" type="math/tex">I \times J</script> ,FLOPS 为 (2I1)×J ( 2 I − 1 ) × J <script id="MathJax-Element-22" type="math/tex">(2I-1) \times J</script> 。

举例:

一个全连接层,输入 100 维,输出 300 维,MACCs 有 300×100=30,000 300 × 100 = 30 , 000 <script id="MathJax-Element-23" type="math/tex">300 \times 100 = 30,000</script> 。不过,如果一个全连接层紧接着卷积层,输入可能没有指定长度 I I <script id="MathJax-Element-24" type="math/tex">I</script> 但有 feature map 的尺寸比如(512, 7, 7)。在 Keras 里就需要写一行 Flatten 把它展平,这样此时的 I I <script id="MathJax-Element-25" type="math/tex">I</script> 就是 512×7×7 512 × 7 × 7 <script id="MathJax-Element-26" type="math/tex">512 \times 7 \times 7</script> 了。

1.2 激活函数

FC 完了接下来通常有个激活函数,ReLU 或者 Sigmoid。激活函数的计算没有点积,所以只用 FLOPS 衡量。

对输出为 J J <script id="MathJax-Element-27" type="math/tex">J</script> 的 FC 层,ReLU 有 J J <script id="MathJax-Element-28" type="math/tex">J</script> FLOPS:

y = max(x, 0)

相比之下 Sigmoid 就复杂很多。

y = 1/(1+exp(-x))

我们把加减乘除、指数、平方根等等运算都算作一次 FLOPS,这里有除法、加法、指数和减法四种运算,所以 FLOPS 就是 J×4 J × 4 <script id="MathJax-Element-29" type="math/tex">J \times 4</script> 。

相对于全连接的矩阵运算,激活函数的计算量通常忽略不计(博主注:不一定,看情况)

1.3 卷积层

卷积层要单独算而不是用全连接层的结论,是因为输入至少是三维的: H×W×C H × W × C <script id="MathJax-Element-30" type="math/tex">H \times W \times C</script> 。对于这样的卷积层,MACCs 有:

K×K×Cin×Hout×Wout×Cout K × K × C i n × H o u t × W o u t × C o u t <script id="MathJax-Element-31" type="math/tex">K \times K \times Cin \times Hout \times Wout \times Cout</script>

解释一下:

  • 输出的 feature map 里每个通道上有 Hout×Wout H o u t × W o u t <script id="MathJax-Element-32" type="math/tex">Hout \times Wout</script> 个元素,
  • 权重以 K×K K × K <script id="MathJax-Element-33" type="math/tex">K \times K</script> 大小的窗口,在所有的 Cin C i n <script id="MathJax-Element-34" type="math/tex">Cin</script> 个通道上做点积,
  • 共有 Cout C o u t <script id="MathJax-Element-35" type="math/tex">Cout</script> 个卷积核,上述操作重复了 Cout C o u t <script id="MathJax-Element-36" type="math/tex">Cout</script> 次

同样,这里也忽略了偏置和激活函数。不应该忽略的是 stride(步长)、dilation factors(漏孔/膨胀卷积)、padding(填充),这就是为什么直接从输出尺寸 Hout×Wout H o u t × W o u t <script id="MathJax-Element-37" type="math/tex">Hout \times Wout</script> 开始算的原因——都已经考虑在内了。

举例:

3×3 3 × 3 <script id="MathJax-Element-38" type="math/tex">3 \times 3</script> 卷积,128 个 filer,输入的 feature map 是 112×112×64 112 × 112 × 64 <script id="MathJax-Element-39" type="math/tex">112 \times 112 \times 64</script> ,stride=1padding=same,MACCs 有:

3×3×64×112×112×128=924,844,032 3 × 3 × 64 × 112 × 112 × 128 = 924 , 844 , 032 <script id="MathJax-Element-40" type="math/tex">3 \times 3 \times 64 \times 112 \times 112 \times 128 = 924, 844, 032</script>

接近十亿的乘-加操作。

1.4 Batch Normalization

计算公式:

z = gamma * (y - mean) / sqrt(variance + epsilon) + beta

首先以输入为卷积层的情况为例。

每个通道上都存在一组 meanbetagammavariance C C <script id="MathJax-Element-41" type="math/tex">C</script> 个通道就有 C×4 C × 4 <script id="MathJax-Element-42" type="math/tex">C \times 4</script> 个可学习的参数。而且 BN 是作用在每一个元素上的,这样看来,造成的 FLOPS 应该不少。

但有趣的是,在 BN 直接连接卷积层的情况下,即 Conv-BN-ReLU 时,通过一组推导,可以将 BN 的计算整合到卷积层当中(注意这是 inference 的情况,跟训练阶段差别很大),从而消去的 BN 层造成的 FLOPS。如果是 Conv-ReLU-BN 的结构这一套就行不通了。

( BN 层的计算结合到 Conv 层中去,BN 层的 FLOPS 消失了,Conv 层需要乘一个常系数)

即从结果上来说,在 inference 时模型中的 BN 层实际被消去了。

1.5 其他层

像 Pooling 层虽然确实很关键,但没有用到点积运算,所以 MACCs 不能很好地衡量这部分计算消耗。如果用 FLOPS,可以取 feature map 的尺寸然后乘一个常系数。

如 maxpooling 层,stride=2filter_sz=2(即输出保持相同尺寸),112 x 112 x 128 的feature map,FLOPS 就是 112 x 112 x 128 = 1,605,632 。相对卷积层和全连接层的运算,这个计算量比较小,所以也可以忽略不计。

RNN 这里不做讨论。简单来说,以 LSTM 为例,计算主要是两个大的矩阵乘法,sigmoid,tanh 和一些元素级的操作。可以看成两个全连接层的运算,所以 MACCs 主要取决于输入、输出和隐状态向量的尺寸。点积运算还是占了大头。

2. 内存占用

内存带宽其实比 MACCs 更重要。目前的计算机结构下,单次内存访问比单次运算慢得多的多。

对每一层网络,设备需要:

  • 从主内存中读取输入向量 / feature map
  • 从主内存中读取权重并计算点积
  • 将输出向量或 feature map 写回主内存

涉及大量的内存访问。内存是很慢的,所以网络层的内存读写对速度有很大的影响,可能比计算耗时还要多。

2.1 权重的内存占用

全连接层有 I x J 大小的权重矩阵,加上偏置向量共计 (I + 1) x J

卷积层的 kernel 通常是正方形的,对 kernel_sz = K 和输入通道为 Cin 、输出 Cout 和额外的 Cout 个偏置的情况,共有 (K x K x Cin + 1) x Cout 个参数。对比之下卷积层的参数量远小于全连接。

举例:

全连接层有4096个输入和4096个输出,所以权重数 (4096+1) x 4096 = 16.8M

3 x 3 、48个卷积核,在64x64 、32个通道的输入上计算,共有 3 x 3 x 32 x 48 + 48 = 13, 872 个权重。

注意到此处卷积层的输入实际是全连接层的32倍(通道),输出是48倍,然鹅权重数只有后者的千分之一不到。全连接层的内存占用真的很可怕。

作者注:卷积层可以看作一个受限连接的全连接层,即权重对 k x k 以外的输入置零,不使用。

2.2 feature maps 和中间结果

CS231n 的 Lesson 9 专门花了很多篇幅讲 feature map 的计算,可以参考。

还是举例说明。卷积层的输入是 224x224x3 ,把所有这些值读出来需要访问 150,528 次内存。如果卷积核是 KxKxCout ,还要乘上这个系数(因为每次卷积都要访问一遍)。拿 stride=2, kernel 数为32的情况来说,输出的 feature map 尺寸为 112x112x32,共计 401,408 次内存访问。

所以,每层的内存访问总数如下:

input = Hin x Win x Cin x K x K x Cout

output = Hout x Wout x Cout

weights = K x K x Cin x Cout + Cout ,按上例:

input = 224 x 224 x 3 x 3 x 3 x 32 = 43,352,064
output = 112 x 112 x 32 = 401,408
weights = 3 x 3 x 3 x 32 + 32 = 896
total = 43,754,368

当网络层数加深时,Hin Win 会越来越小,但通道数会变得很大:

input = 28 x 28 x 256 x 3 x 3 x 512 = 924,844,032
output = 28 x 28 x 512 = 401,408
weights = 3 x 3 x 256 x 512 + 512 = 1,180,160
total = 926,425,600

这种情况下 weights 部分也会变得很大,所以是不能忽略的。

2.3 Fusion

这一节的意思是,像 ReLU 这样比较简单的运算,如果不做优化,在计算时近乎是从输入到输出做了一次拷贝。计算可以认为不耗时间,但内存访问还是有消耗的,所以可以把这一步同卷积层的计算合成,从而节省了一轮内存读写。

3. MobileNet V2 vs. V1

这部分作者讲了他认为 V2 不会比 V1 快的分析过程。结论跟开头博主引的图相近,即乘子都为1.0时,V2是显著快于V1的,但V2在乘子为1.4时速度比V1稍慢。

至于原因嘛,简单来说就是 V2 的层数更深,每层的输入输出参数读写导致内存访问量大增。因此作者认为影响 inference 速度的瓶颈其实不在 MACCs,而是内存访问数(memory accesses)。

V2 with multiplier=1.4 的速度略慢于 V1,但精度高出不少;V2 with multiplier=1.0 速度比 V1 快很多。可以根据需要进行取舍。官方页面上也给了很多实验参考。

然后作者对 VGG16 做了一点考察,结论很有意思。

VGG16 经常被当作图像方面的特征提取器,结构很简单,层数也不多,看起来好像计算比较多、内存访问会少一些,真的是这样吗?对比 MobileNet(输入按移动设备16:9的规格,是126x224,可以算出以下结果:

VGG16 params: 15M
VGG16 MACCs : 8380M
VGG16 MAes  : 8402M

所以更大的 feature map 导致了更多的内存访问。

4 结论

论文中 MobileNet V2 主要比较了 MACCs 和参数量,指出因为这两项规模更小所以速度更快。但实际上还要考虑内存访问的情况。

另外本文给出的 MACCs、内存访问、参数量都是估计值,只用于同类模型的复杂度比较,出了这个语境是毫无意义的。

进一步阅读

论文:

©️2020 CSDN 皮肤主题: 大白 设计师:CSDN官方博客 返回首页