对于图像的卷积操作实现中,现有框架都支持设置 mini-batch size,此时一个 mini-batch 中的图像信息为 NHWC(按照 batch 维度放最前)。本文中暂不考虑 mini-batch 的情况,按照输入图像信息为 HWC、N 个大小为 K 的 filter 计算基本卷积,过程中不考虑 dilate 等情况。
首先说明卷积过程时,一般采用下图类似的方式(图片来源于此),更方便理解。filter 和 input 的通道数一致均为 C,每个 filter (K × K × C)和 input 中对应的 patch (K × K × C)做对点相乘求和后得到output 中的一个点,output 的通道数由 filter 的数量 N 决定,output 单个通道的宽高由 input 和 kernel 的尺寸共同决定。
但是实际计算中,借助已经较为成熟的矩阵运算来计算卷积会更快一些,可以参考下图的 “
W
X
=
Y
WX = Y
WX=Y” 。
- 将每个 filter(K × K × C) 拉平展开为 1 × K × K × C 的向量,将所有 filter 堆叠起来形成 NKKC 的权重矩阵;
- 在 input 的单通道中复制和 filter size 大小一致的 patch 纵向拉平为 KK × 1 的向量,复制之后的 patch 横向堆叠,形成 KK × HW 的矩阵;
- 对 input 的每个通道执行上一步,并且将各个通道的结果纵向堆叠,形成 KKC × HW 的 input 信息矩阵。
- 矩阵乘法,得到最后结果,可以 reshape 到 N × H × W 作为输出。
coding 时,不考虑非法输入的情况,需要考虑:input,filter 以及 output 的维度定义具体是什么?这涉及到后面计算中矩阵元素位置对应,所以要想清楚。暂时定义:
- input:H × W × C;
- filters:N × K × K;
- output:outH × outW × N;
- patch:K × K × C。
接下来考虑以下几个步骤,总原则坚持 filter2row 一行一个 filter,img2col 一列一个 patch:
(1)filter2row:filters 的展开,每行对应 1 个filter 信息,filter2row:N × KKC;
(2)pad_img:根据 padding 的模式来对 input 进行 pad 得到新的 input,仍然命名为 input:H × W × C;
(3)cal output shape:根据 kernel_size,padding 模式以及 stride 的情况,计算 outH 和 outW;
(4)fetch_patch:按照步长 stride 从 input 中拿 patch 出来,注意边上不完整的需要舍弃,patch 拿出来维度是 K × K × C,注意 transpose 再 reshape 到一列,列长为 KKC;
(5)img2col:将 patch 放在一起生成 img2col:KKC × outHoutW
(6)matmul:filter2row 和 img2col 进行矩阵乘法得到 KKC × outHoutW,然后注意 transpose 变换再 reshape 到 outH × outW × N 的维度。
如果加入 mini-batch 的话,图像等大按照通道叠加在一起比较方便,我觉得在 filter2row 基础上行 repeat batch 份,img2col 纵向上变成 KKC batch。
import numpy as np
import cv2
"""
kernels: outC, k, k, inC
input: inH, inW, inC
output: outH, outW, outC
kernel2row: outC, k * k * inC
img2col: k * k * inC, outH * outW
padding
stride
"""
class Conv2D:
def __init__(self, filters, input_img, padding='same', stride=2):
self.inC = input_img.shape[2]
self.filters = filters
self.padding = padding
self.stride = stride
self.input_img = self.pad_img(input_img)
self.filter_row = self.filters2row()
self.outH, self.outW = self.cal_out_shape()
self.img_col = self.img2col()
self.output_img = self.matmul()
return
def pad_img(self, input_img):
assert self.padding in ['same', 'valid']
H, W, C = input_img.shape
K = self.filters.shape[1]
p = K // 2
if self.padding == 'same':
img_padded = np.zeros((H + K - 1, W + K - 1, C), dtype=float)
img_padded[p: H + p, p: W + p, :] = input_img
return img_padded
elif self.padding == 'valid':
return input_img
def filters2row(self):
if len(self.filters.shape) == 3:
self.filters = np.repeat(np.expand_dims(self.filters, axis=3), self.inC, axis=3)
self.outC, k, k, inC = self.filters.shape
self.filter_size = k
filter_row = np.transpose(self.filters, (0, 3, 1, 2))
filter_row = np.reshape(filter_row, (self.outC, k * k * inC))
return filter_row
def img2col(self):
H, W, inC = self.input_img.shape
self.img_col = np.zeros((self.filter_size ** 2 * inC, self.outH * self.outW), dtype=float)
for i in range(0, H, self.stride):
if i + self.stride >= H: break
for j in range(0, W, self.stride):
if j + self.stride >= W: break
patch = self.input_img[i: i + self.filter_size, j: j + self.filter_size, :]
patch = self.patch2col(patch)
self.img_col[:, i // self.stride * self.outW + j // self.stride] = patch
return self.img_col
def cal_out_shape(self):
self.inH, self.inW = self.input_img.shape[0:2]
self.outH = (self.inH - self.filter_size) // self.stride + 1
self.outW = (self.inW - self.filter_size) // self.stride + 1
return self.outH, self.outW
def patch2col(self, patch):
k, k, inC = patch.shape
patch = np.transpose(patch, (2, 0, 1))
patch = np.reshape(patch, (k * k * inC))
return patch
def matmul(self):
output = np.matmul(self.filter_row, self.img_col)
output = np.transpose(output, (1, 0))
output = np.reshape(output, (self.outH, self.outW, self.outC))
return output
filters = [[[-1, 0, 1], [-1, 0, 1], [-1, 0, 1]],
[[1, 0, -1], [1, 0, -1], [1, 0, -1]],
[[1, 1, 1], [0, 0, 0], [-1, -1, -1]],
[[-1, -1, -1], [0, 0, 0], [1, 1, 1]]]
filters = np.array(filters)
img = np.zeros((100, 100, 3), dtype=float)
img[25: 75, 25: 75, :] = 1
Conv2D = Conv2D(filters, img, padding='same', stride=2)
output = Conv2D.output_img
cv2.imshow('orig', img)
print(filters.shape)
print(output.shape)
cv2.imshow('c0', output[:, :, 0])
cv2.imshow('c1', output[:, :, 1])
cv2.imshow('c2', output[:, :, 2])
cv2.imshow('c3', output[:, :, 3])
cv2.waitKey()