前言
在我的上一篇文章「 看完这篇我就不信还有人不懂卷积神经网络!」中,有位掘友在评论区留言道:要是能讲讲 img2col 就好了。今天我就安排上了,于是就有了本文。
温馨提示:本文需要一些卷积神经网络方面的基础知识,不了解的同学需要先看看上面这篇文章☝。
简介
img2col 是一种实现卷积操作的加速计算策略。它能将卷积操作转化为 GEMM,从而最大化地缩短卷积计算的时间。
GEMM 是通用矩阵乘 (General Matrix Multiply) 的英文缩写,其实就是一般意义上的矩阵乘法,数学表达就是 C = A x B。根据上下文语境,GEMM 有时也指实现矩阵乘法的函数接口。
为什么要将卷积操作转化为 GEMM 呢?
-
因为线性代数领域已经有非常成熟的计算接口(BLAS,Fortran 语言实现)来高效地实现大型的矩阵乘法,几乎可以做到极限优化。
-
将卷积过程中用到的所有特征子矩阵整合成一个大型矩阵存放在连续的内存中,虽然增加了存储成本,但是减少了内存访问的次数,从而缩短了计算时间。
原理
img2col 的原理可以用下面这一张图来概括:
此图☝出自这篇论文: High Performance Convolutional Neural Networks for Document Processing,感兴趣的同学可以自行去拜读一下。
上面这张图还是有些抽象,下面我们一步一步来分解上面这张图:
1. Input Features -> Input Matrix
不难看出,输入特征图一共有三个通道,我们以不同的颜色来区分。
以蓝色的特征图为例,它是一个 3 x 3 的矩阵,而卷积核是一个 2 x 2 的矩阵,当卷积核的滑动步长为 1 时,那么传统的直接卷积计算一共需要进行 4 次卷积核与对应特征子矩阵之间的点积运算。
现在我们把每一个特征子矩阵都排列成一个行向量(如图中编号1️⃣、2️⃣所示),然后把这 4 个行向量堆叠成一个新的矩阵,就得到了蓝色特征图所对应的 Input Matrix。
当输入特征图不止一个通道时,则对每一个通道的特征图都采用上述操作,然后再把每一个通道对应的 Input Matrix 堆叠成一个完整的 Input Matrix。
2. Convolution Kernel -> Kernel Matrix
不难看出,卷积核一共有两个,每个均为三通道,我们以第一个卷积核为例进行讲解。
将卷积核转化成矩阵的方式和第一步有些类似,只是这里应该转化成列向量(如图中编号1️⃣、2️⃣、3️⃣所示)。如果第一步转化成列向量,则这里应该转化成行向量,这是由矩阵乘法的计算特性决定的,即一个矩阵的每一行和另一个矩阵的每一列做内积,所以特征图和卷积核只能一个展开为行,一个展开为列。
同样地,如果卷积核有多个通道,则对每一个通道的卷积核都采用上述操作,然后再把每一个通道对应的 Kernel Matrix 堆叠成一个完整的 Kernel Matrix。
3. Input Matrix * Kernel Matrix = Output Matrix
在得到上述两个矩阵之后,接下来调用 GEMM 函数接口进行矩阵乘法运算即可得到输出矩阵,然后将输出矩阵通过 col2img 函数就可以得到和卷积运算一样的输出特征图。
结语
通过 img2col 函数,我们只需执行一次矩阵乘法计算就能得到与卷积运算相同的结果,而传统的直接卷积计算光是一个通道就需要进行 4 次(仅指本例中)卷积核与对应特征子矩阵之间的点积运算,那么如果通道数特别多?输入特征图非常庞大呢?那计算的次数将是成倍增长的!
有些同学可能会担心将所有特征子矩阵都堆叠到一个矩阵中,会不会导致内存不够用或者计算速度非常慢,尤其是在深度神经网络中。其实不用担心,因为矩阵的存储和计算其实都是非常规则的,很容易通过分布式和并行的方式来解决,感兴趣的同学可以自行阅读相关论文。
代码实现:
# encoding: utf-8
import numpy as np
def img2col(image: np.ndarray, kernel_size: int, stride: int = 1, padding: int = 0) -> np.ndarray:
h, w = image.shape
out_h = (h + 2 * padding - kernel_size) // stride + 1
out_w = (w + 2 * padding - kernel_size) // stride + 1
pad_image = np.zeros((h + 2 * padding, w + 2 * padding))
pad_image[padding:h+padding, padding:w+padding] = image
ph, pw = pad_image.shape
res = np.zeros((out_h * out_w, kernel_size * kernel_size))
idx = 0
# x * (k^2) * k^2
for i in range(0, ph - kernel_size + 1, stride):
for j in range(0, pw - kernel_size + 1, stride):
res[idx] = pad_image[i:i + kernel_size, j:j + kernel_size].reshape(-1)
idx += 1
return res
img = np.zeros((6,3))
print(img)
print(img2col(img, 2))
再来一个Pytorch版本的:
# encoding: utf-8
import torch
def img2col(image: torch.Tensor, kernel_size: int, stride: int = 1, padding: int = 0) -> torch.Tensor:
# 将变成(h,w)的Tensor转换成(out_h*out_w,kernal_size*kernal_size)的Tensor,其中out_h = (h + 2 * padding - kernel_size) // stride + 1
in_h, in_w = image.shape
out_h = (in_h + 2 * padding - kernel_size) // stride + 1
out_w = (in_w + 2 * padding - kernel_size) // stride + 1
print('out_w={} out_h={}'.format(out_w, out_h))
pad_image = torch.zeros((in_h + 2 * padding, in_w + 2 * padding))
pad_image[padding:in_h+padding, padding:in_w+padding] = image
ph, pw = pad_image.shape
res = torch.zeros((out_h * out_w, kernel_size * kernel_size))
idx = 0
# x * (k^2) * k^2
for i in range(0, ph - kernel_size + 1, stride):
for j in range(0, pw - kernel_size + 1, stride):
res[idx] = pad_image[i:i + kernel_size, j:j + kernel_size].reshape(-1)
idx += 1
return res
img = torch.rand((6,3))
print('raw_img={}'.format(img))
# print(img2col(img, 2).shape)
print('img2col={}'.format(img2col(img, kernel_size=2, padding=1, stride=1)))
作者:Codeman
来源:稀土掘金
著作权归作者所有。商业转载请联系作者获得授权,非商业转载请注明出处。