im2col
最近在研究caffe的源码,im2col是caffe非常重要的部分。总结来说,就是把卷积运算变成矩阵运算,具体怎么做到的,下面通过源码还有一个例子来详细说明一下。
源码
首先,先看一下caffe的作者贾大神的ppt,是从知乎上找到的。
C:图像的通道,K:卷积核的大小,H:图像的高,W:图像的宽
这是比较形象的转换,但个人觉得有个错误,就是Feature Matrix的维度应该不是(H*W)* (C*K*K)。首先在代码中,它存储的是上图中矩阵的转置。即Feature Matrix的高为C*K*K,它的宽应该和卷积后输出的特征图的维度有关,应该是输出特征图的高*输出特征图的宽。
解析下源码:
template <typename Dtype>
void im2col_cpu(const Dtype* data_im, const int channels,
const int height, const int width, const int kernel_h, const int kernel_w,
const int pad_h, const int pad_w,
const int stride_h, const int stride_w,
const int dilation_h, const int dilation_w,
Dtype* data_col) {
// 计算输出特征图的高和宽
const int output_h = (height + 2 * pad_h -
(dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int output_w = (width + 2 * pad_w -
(dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
// 计算一个通道的数据量
const int channel_size = height * width;
// 第一个for遍历每个通道,data_im表示输入数据指针移动一个通道的数据量
for (int channel = channels; channel--; data_im += channel_size) {
// 第二个和第三个for遍历Kernel的每个位置,按顺序从左到右,从上到下
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
// kernel的一个位置在图像中的行数
int input_row = -pad_h + kernel_row * dilation_h;
// 第四个和第五个for遍历data_col的一行,一次填充相应的数
for (int output_rows = output_h; output_rows; output_rows--) {
if (!is_a_ge_zero_and_a_lt_b(input_row, height)) {
for (int output_cols = output_w; output_cols; output_cols--) {
*(data_col++) = 0;
}
} else {
int input_col = -pad_w + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (is_a_ge_zero_and_a_lt_b(input_col, width)) {
*(data_col++) = data_im[input_row * width + input_col];
} else {
*(data_col++) = 0;
}
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
}
在这里,我要说一下自己的理解,因为这个代码,我看了好久,网上查了很多资料,都很迷惑,突然想通一个点,就全通了。
首先,从图片格式变成矩阵格式,我们也从上图中了解到,矩阵data_col的高是:C*K*K,即filter的数据量,矩阵data_col的宽是:输出特征图的高* 输出矩阵的宽。
代码中的具体操作是:往data_col中填数,顺序是从左到右,从上到下,依次填入。
那么data_col中第一行的数,应该是图像中与卷积核(0,0,0)坐标位置相乘的所有数字。表达的可能不是很清楚,我做了一个图,看了应该很明白。
例子:
我们假设是2维卷积操作,图像大小为3*3,padding是1,卷积核大小为3*3,stride为1.
可能看的不是很清晰,但是表达应该很明白。卷积核遍历所有位置,首先把蓝色部分即卷积核的(0,0)依次填入data_col的第一行,然后卷积核的(0, 1)位置....,直到data_col填满为止。