The code comes from:
http://www.heibanke.com/2016/10/11/conv_layer_forward_backward/#comments
Convolution Layer Forward
import numpy as np
%load_ext autoreload
%autoreload 2
def conv_forward_naive(x, w, b, conv_param):
    """
    A naive implementation of the forward pass for a convolutional layer.

    The input consists of N data points, each with C channels, height H and
    width W. We convolve each input with F different filters, where each
    filter spans all C channels and has height HH and width WW.

    Input:
    - x: Input data of shape (N, C, H, W)
    - w: Filter weights of shape (F, C, HH, WW)
    - b: Biases, of shape (F,)
    - conv_param: A dictionary with the following keys:
      - 'stride': The number of pixels between adjacent receptive fields in
        the horizontal and vertical directions.
      - 'pad': The number of pixels that will be used to zero-pad the input.

    Returns a tuple of:
    - out: Output data, of shape (N, F, H', W') where H' and W' are given by
      H' = 1 + (H + 2 * pad - HH) / stride
      W' = 1 + (W + 2 * pad - WW) / stride
    - cache: (x, w, b, conv_param)
    """
    N, C, H, W = x.shape
    F, _, HH, WW = w.shape
    S = conv_param['stride']
    P = conv_param['pad']
    # Output spatial size; integer division so the result is usable as an index.
    Ho = 1 + (H + 2 * P - HH) // S
    Wo = 1 + (W + 2 * P - WW) // S
    # Zero-pad the input along the two spatial dimensions only.
    x_pad = np.zeros((N, C, H + 2 * P, W + 2 * P))
    x_pad[:, :, P:P + H, P:P + W] = x
    out = np.zeros((N, F, Ho, Wo))
    for f in range(F):
        for i in range(Ho):
            for j in range(Wo):
                # Each output element is one filter dotted with one receptive
                # field, summed over channels and both spatial axes.
                out[:, f, i, j] = np.sum(
                    x_pad[:, :, i * S:i * S + HH, j * S:j * S + WW] * w[f, :, :, :],
                    axis=(1, 2, 3))
        out[:, f, :, :] += b[f]
    cache = (x, w, b, conv_param)
    return out, cache
Let's try it on a small example and check the output:
x_shape = (2, 3, 4, 4)
w_shape = (2, 3, 3, 3)
x = np.ones(x_shape)
w = np.ones(w_shape)
b = np.array([1,2])
conv_param = {'stride': 1, 'pad': 0}
out, _ = conv_forward_naive(x, w, b, conv_param)
print(out)
print(out.shape)
The result (each output element sums a 3×3×3 window of ones, i.e. 27, plus the filter's bias):
[[[[ 28.  28.]
   [ 28.  28.]]

  [[ 29.  29.]
   [ 29.  29.]]]


 [[[ 28.  28.]
   [ 28.  28.]]

  [[ 29.  29.]
   [ 29.  29.]]]]
(2, 2, 2, 2)
Now set pad to 1 and work out the result yourself, paying particular attention to how the output dimensions change. What about stride 3 with pad 1? What other settings are possible? (A quick way to check your answers is sketched below.)
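As a sanity check on the dimension arithmetic, here is a minimal sketch (not in the original post) that reuses conv_forward_naive and the formula H' = 1 + (H + 2 * pad - HH) / stride from the docstring:

x = np.ones((2, 3, 4, 4))
w = np.ones((2, 3, 3, 3))
b = np.array([1, 2])
# Try the settings suggested above; note (H + 2*pad - HH) must be divisible
# by the stride for the window positions to tile the input exactly.
for stride, pad in [(1, 0), (1, 1), (3, 1)]:
    out, _ = conv_forward_naive(x, w, b, {'stride': stride, 'pad': pad})
    print('stride=%d pad=%d -> %s' % (stride, pad, out.shape))
# Expected: (2, 2, 2, 2), then (2, 2, 4, 4), then (2, 2, 2, 2)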
Convolution Layer Backward
def conv_backward_naive(dout, cache):
    """
    A naive implementation of the backward pass for a convolutional layer.

    Inputs:
    - dout: Upstream derivatives.
    - cache: A tuple of (x, w, b, conv_param) as in conv_forward_naive

    Returns a tuple of:
    - dx: Gradient with respect to x
    - dw: Gradient with respect to w
    - db: Gradient with respect to b
    """
    N, F, H1, W1 = dout.shape
    x, w, b, conv_param = cache
    N, C, H, W = x.shape
    HH = w.shape[2]
    WW = w.shape[3]
    S = conv_param['stride']
    P = conv_param['pad']
    dx, dw, db = np.zeros_like(x), np.zeros_like(w), np.zeros_like(b)
    x_pad = np.pad(x, [(0, 0), (0, 0), (P, P), (P, P)], 'constant')
    dx_pad = np.pad(dx, [(0, 0), (0, 0), (P, P), (P, P)], 'constant')
    # The bias of each filter is added at every output position, so db is
    # simply dout summed over the batch and spatial dimensions.
    db = np.sum(dout, axis=(0, 2, 3))
    for n in range(N):
        for i in range(H1):
            for j in range(W1):
                # The receptive field that produced dout[n, :, i, j].
                x_window = x_pad[n, :, i * S:i * S + HH, j * S:j * S + WW]
                for f in range(F):
                    # Accumulate the weight gradient and scatter the upstream
                    # gradient back onto the padded input.
                    dw[f] += x_window * dout[n, f, i, j]
                    dx_pad[n, :, i * S:i * S + HH, j * S:j * S + WW] += w[f] * dout[n, f, i, j]
    # Strip the padding to recover the gradient with respect to x.
    dx = dx_pad[:, :, P:P + H, P:P + W]
    return dx, dw, db
The implementation above is the most basic one. On MATLAB, people speed things up by implementing this procedure with the built-in conv function, which is where the "flip the kernel 180 degrees twice" trick mentioned in many blog posts comes from; all that flipping back and forth actually makes the process harder to understand. In fact, the forward and backward passes of a convolution layer have no direct relation to the convolution operation from signal processing: they are just correlation and element-wise multiplication. Every other implementation is an optimization for speed.
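To make the "flip 180 degrees twice" remark concrete: signal-processing convolution is exactly correlation with a kernel flipped along both spatial axes, so flipping twice cancels out. A minimal sketch, assuming scipy is available (it is not used in the original post):

import numpy as np
from scipy.signal import convolve2d, correlate2d

rng = np.random.default_rng(0)
img = rng.random((5, 5))
k = rng.random((3, 3))
# convolve2d flips the kernel before sliding it; correlate2d slides it as-is,
# which is what conv_forward_naive computes.
assert np.allclose(convolve2d(img, k, mode='valid'),
                   correlate2d(img, k[::-1, ::-1], mode='valid'))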
Let's also run an example for the backward pass:
x_shape = (2, 3, 4, 4)
w_shape = (2, 3, 3, 3)
x = np.ones(x_shape)
w = np.ones(w_shape)
b = np.array([1,2])
conv_param = {'stride': 1, 'pad': 0}
Ho = (x_shape[3] + 2 * conv_param['pad'] - w_shape[3]) // conv_param['stride'] + 1
Wo = Ho
dout = np.ones((x_shape[0], w_shape[0], Ho, Wo))
out, cache = conv_forward_naive(x, w, b, conv_param)
dx, dw, db = conv_backward_naive(dout, cache)
print "out shape",out.shape
print "dw=========================="
print dw
print "dx=========================="
print dx
print "db=========================="
print db
out shape (2, 2, 2, 2)
dw==========================
[[[[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]

  [[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]

  [[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]]


 [[[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]

  [[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]

  [[ 8.  8.  8.]
   [ 8.  8.  8.]
   [ 8.  8.  8.]]]]
dx==========================
[[[[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]

  [[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]

  [[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]]


 [[[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]

  [[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]

  [[ 2.  4.  4.  2.]
   [ 4.  8.  8.  4.]
   [ 4.  8.  8.  4.]
   [ 2.  4.  4.  2.]]]]
db==========================
[ 8.  8.]
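A stronger check than these all-ones examples is a numerical gradient comparison. The sketch below is not from the original post: numeric_grad is a hypothetical centered-difference helper that contracts the numerical Jacobian with dout. All three printed maximum differences should be on the order of 1e-8.

import numpy as np

def numeric_grad(f, x, dout, h=1e-5):
    # Centered-difference numerical gradient of sum(f(x) * dout) w.r.t. x.
    grad = np.zeros_like(x)
    it = np.nditer(x, flags=['multi_index'])
    while not it.finished:
        idx = it.multi_index
        old = x[idx]
        x[idx] = old + h
        pos = f(x).copy()
        x[idx] = old - h
        neg = f(x).copy()
        x[idx] = old
        grad[idx] = np.sum((pos - neg) * dout) / (2 * h)
        it.iternext()
    return grad

np.random.seed(231)
x = np.random.randn(2, 3, 4, 4)
w = np.random.randn(2, 3, 3, 3)
b = np.random.randn(2)
conv_param = {'stride': 1, 'pad': 1}
out, cache = conv_forward_naive(x, w, b, conv_param)
dout = np.random.randn(*out.shape)
dx, dw, db = conv_backward_naive(dout, cache)

dx_num = numeric_grad(lambda x: conv_forward_naive(x, w, b, conv_param)[0], x, dout)
dw_num = numeric_grad(lambda w: conv_forward_naive(x, w, b, conv_param)[0], w, dout)
db_num = numeric_grad(lambda b: conv_forward_naive(x, w, b, conv_param)[0], b, dout)
print(np.max(np.abs(dx - dx_num)))
print(np.max(np.abs(dw - dw_num)))
print(np.max(np.abs(db - db_num)))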