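Suppose we have already obtained, by recursion, the gradient error $\delta^{l+1}$ of the layer above. The output $z$ of the convolutional layer is related to its input $a$ and the parameters $W$, $b$ by:
$$z^l = a^l * W^l + b$$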
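Hence the residual of this layer ($\delta^l$) and the residual of the layer above ($\delta^{l+1}$) satisfy the following recurrence:
Pseudo-formula:
$$\delta^l = \frac{\partial J(W,b)}{\partial a^l} = \frac{\partial J(W,b)}{\partial z^l}\,\frac{\partial z^l}{\partial a^l} = \delta^{l+1} * W^l$$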
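The pseudo-formula above is only meant to make the derivation easier to follow. The actual formula is:
$$\delta^l = \frac{\partial J(W,b)}{\partial a^l} = \frac{\partial J(W,b)}{\partial z^l}\,\frac{\partial z^l}{\partial a^l} = \mathrm{pad}(\delta^{l+1}) \otimes \mathrm{rot180}(W^l)$$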
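Take the following forward pass as an example, where a $2 \times 2$ input is zero-padded by 1 and $\otimes$ slides the $3 \times 3$ kernel over it without flipping the kernel:
$$
\begin{pmatrix}
0 & 0 & 0 & 0 \\
0 & a_{11} & a_{12} & 0 \\
0 & a_{21} & a_{22} & 0 \\
0 & 0 & 0 & 0
\end{pmatrix}
\otimes
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33}
\end{pmatrix}
=
\begin{pmatrix}
z_{11} & z_{12} \\
z_{21} & z_{22}
\end{pmatrix}
$$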
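The gradient error $\delta^{l+1}$ of $z$ passed back by backpropagation is:
$$
\begin{pmatrix}
\delta_{11} & \delta_{12} \\
\delta_{21} & \delta_{22}
\end{pmatrix}
$$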
Using the definition of convolution, it is easy to obtain:
$$
\begin{aligned}
z_{11} &= 0 + 0 + 0 + 0 + a_{11}w_{22} + a_{12}w_{23} + 0 + a_{21}w_{32} + a_{22}w_{33} \\
z_{12} &= 0 + 0 + 0 + a_{11}w_{21} + a_{12}w_{22} + 0 + a_{21}w_{31} + a_{22}w_{32} + 0 \\
z_{21} &= 0 + a_{11}w_{12} + a_{12}w_{13} + 0 + a_{21}w_{22} + a_{22}w_{23} + 0 + 0 + 0 \\
z_{22} &= a_{11}w_{11} + a_{12}w_{12} + 0 + a_{21}w_{21} + a_{22}w_{22} + 0 + 0 + 0 + 0
\end{aligned}
$$
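The four expansions can be checked numerically. The following is a minimal sketch, assuming NumPy, stride 1 and zero padding of width 1; the helper name `cross_correlate_valid` and the random test values are illustrative and not part of the original text.

```python
import numpy as np

def cross_correlate_valid(x, w):
    """Slide w over x without flipping it (stride 1, no extra padding)."""
    kh, kw = w.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * w)
    return out

rng = np.random.default_rng(0)
a = rng.standard_normal((2, 2))   # a[0, 0] plays the role of a_11, and so on
w = rng.standard_normal((3, 3))   # w[0, 0] plays the role of w_11, and so on

a_pad = np.pad(a, 1)              # zero border of width 1, giving a 4x4 array
z = cross_correlate_valid(a_pad, w)

# The same four values written out term by term (zero terms omitted).
z_expected = np.array([
    [a[0, 0] * w[1, 1] + a[0, 1] * w[1, 2] + a[1, 0] * w[2, 1] + a[1, 1] * w[2, 2],
     a[0, 0] * w[1, 0] + a[0, 1] * w[1, 1] + a[1, 0] * w[2, 0] + a[1, 1] * w[2, 1]],
    [a[0, 0] * w[0, 1] + a[0, 1] * w[0, 2] + a[1, 0] * w[1, 1] + a[1, 1] * w[1, 2],
     a[0, 0] * w[0, 0] + a[0, 1] * w[0, 1] + a[1, 0] * w[1, 0] + a[1, 1] * w[1, 1]],
])
print(np.allclose(z, z_expected))
```

Note that the indices in the code are 0-based, so `w[1, 1]` corresponds to $w_{22}$ in the equations.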
From the equations above, we then have:
$$
\begin{aligned}
\frac{\partial J(W,b)}{\partial a^l_{11}} &= w_{22}\delta_{11} + w_{21}\delta_{12} + w_{12}\delta_{21} + w_{11}\delta_{22} \\
\frac{\partial J(W,b)}{\partial a^l_{12}} &= w_{23}\delta_{11} + w_{22}\delta_{12} + w_{13}\delta_{21} + w_{12}\delta_{22} \\
\frac{\partial J(W,b)}{\partial a^l_{21}} &= w_{32}\delta_{11} + w_{31}\delta_{12} + w_{22}\delta_{21} + w_{21}\delta_{22} \\
\frac{\partial J(W,b)}{\partial a^l_{22}} &= w_{33}\delta_{11} + w_{32}\delta_{12} + w_{23}\delta_{21} + w_{22}\delta_{22}
\end{aligned}
$$
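The four gradients follow one pattern, which can be written compactly. The index form below is a sketch that holds only under the assumptions of this example ($2 \times 2$ input, $3 \times 3$ kernel, zero padding 1, stride 1); it is not stated in the original text. Since $a^l_{ij}$ enters $z_{pq}$ with the coefficient $w_{i-p+2,\,j-q+2}$, the chain rule gives:

$$
\frac{\partial J(W,b)}{\partial a^l_{ij}}
= \sum_{p=1}^{2}\sum_{q=1}^{2} \delta_{pq}\,\frac{\partial z_{pq}}{\partial a^l_{ij}}
= \sum_{p=1}^{2}\sum_{q=1}^{2} \delta_{pq}\, w_{i-p+2,\;j-q+2},
\qquad i, j \in \{1, 2\}
$$

For instance, $i = j = 1$ gives $w_{22}\delta_{11} + w_{21}\delta_{12} + w_{12}\delta_{21} + w_{11}\delta_{22}$. The weight index decreases as the $\delta$ index increases, which is the reason the kernel appears rotated by 180 degrees in the matrix form.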
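In total we obtain four equations. Arranged in matrix form, they become:
$$
\frac{\partial J(W,b)}{\partial a^l}
=
\begin{pmatrix}
0 & 0 & 0 & 0 \\
0 & \delta_{11} & \delta_{12} & 0 \\
0 & \delta_{21} & \delta_{22} & 0 \\
0 & 0 & 0 & 0
\end{pmatrix}
\otimes
\begin{pmatrix}
w_{33} & w_{32} & w_{31} \\
w_{23} & w_{22} & w_{21} \\
w_{13} & w_{12} & w_{11}
\end{pmatrix}
$$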
This example confirms the formula given earlier:
$$\delta^l = \frac{\partial J(W,b)}{\partial a^l} = \frac{\partial J(W,b)}{\partial z^l}\,\frac{\partial z^l}{\partial a^l} = \mathrm{pad}(\delta^{l+1}) \otimes \mathrm{rot180}(W^l)$$
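The formula can also be checked against a finite-difference gradient. This is a minimal sketch, assuming NumPy, a single channel, a $3 \times 3$ kernel, forward padding 1 and stride 1 (for other kernel sizes or paddings the width of the padding applied to $\delta^{l+1}$ would differ). The names `forward`, `backward_input` and `loss` are hypothetical helpers introduced for illustration.

```python
import numpy as np

def cross_correlate_valid(x, w):
    """Slide w over x without flipping it (stride 1, no extra padding)."""
    kh, kw = w.shape
    out = np.zeros((x.shape[0] - kh + 1, x.shape[1] - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(x[i:i + kh, j:j + kw] * w)
    return out

def forward(a, w):
    # z = pad(a) ⊗ W, bias omitted because it does not affect dJ/da
    return cross_correlate_valid(np.pad(a, 1), w)

def backward_input(delta, w):
    # dJ/da = pad(delta) ⊗ rot180(W); np.rot90(w, 2) rotates by 180 degrees
    return cross_correlate_valid(np.pad(delta, 1), np.rot90(w, 2))

rng = np.random.default_rng(1)
a = rng.standard_normal((2, 2))
w = rng.standard_normal((3, 3))
delta = rng.standard_normal((2, 2))   # stands in for dJ/dz

def loss(a_in):
    # A loss chosen so that dJ/dz equals delta exactly.
    return np.sum(delta * forward(a_in, w))

grad_formula = backward_input(delta, w)

# Central finite differences, one input entry at a time.
eps = 1e-6
grad_numeric = np.zeros_like(a)
for idx in np.ndindex(a.shape):
    step = np.zeros_like(a)
    step[idx] = eps
    grad_numeric[idx] = (loss(a + step) - loss(a - step)) / (2 * eps)

print(np.allclose(grad_formula, grad_numeric, atol=1e-6))
```

The loss is linear in `a`, so the finite-difference estimate should agree with the formula up to floating-point rounding.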
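$$
\frac{\partial J(W,b)}{\partial a^l}
= \mathrm{pool}(d)
= \mathrm{pool}\left(
\begin{pmatrix}
d_{11} & d_{12} & d_{13} & d_{14} \\
d_{21} & d_{22} & d_{23} & d_{24} \\
d_{31} & d_{32} & d_{33} & d_{34} \\
d_{41} & d_{42} & d_{43} & d_{44}
\end{pmatrix}
\right)
=
\begin{pmatrix}
d_{22} & d_{23} \\
d_{32} & d_{33}
\end{pmatrix}
$$
Note: here pool() denotes cropping. It removes the outermost pad = 1 rows and columns on the top, bottom, left and right of $d$ and keeps what remains.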
Each $\delta_{ij}$ contributes the kernel scaled by $\delta_{ij}$ to one $3 \times 3$ window of $d$. To keep these blocks apart from the individual entries $d_{ij}$, they are written $D^{(ij)}$ below; the first matrix in each line lists the positions of $d$ that the block lands on.
$$
D^{(11)}:\quad
\begin{pmatrix}
d_{11} & d_{12} & d_{13} \\
d_{21} & d_{22} & d_{23} \\
d_{31} & d_{32} & d_{33}
\end{pmatrix}
=
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33}
\end{pmatrix}
\delta_{11}
=
\begin{pmatrix}
w_{11}\delta_{11} & w_{12}\delta_{11} & w_{13}\delta_{11} \\
w_{21}\delta_{11} & w_{22}\delta_{11} & w_{23}\delta_{11} \\
w_{31}\delta_{11} & w_{32}\delta_{11} & w_{33}\delta_{11}
\end{pmatrix}
$$
$$
D^{(12)}:\quad
\begin{pmatrix}
d_{12} & d_{13} & d_{14} \\
d_{22} & d_{23} & d_{24} \\
d_{32} & d_{33} & d_{34}
\end{pmatrix}
=
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33}
\end{pmatrix}
\delta_{12}
=
\begin{pmatrix}
w_{11}\delta_{12} & w_{12}\delta_{12} & w_{13}\delta_{12} \\
w_{21}\delta_{12} & w_{22}\delta_{12} & w_{23}\delta_{12} \\
w_{31}\delta_{12} & w_{32}\delta_{12} & w_{33}\delta_{12}
\end{pmatrix}
$$
$$
D^{(21)}:\quad
\begin{pmatrix}
d_{21} & d_{22} & d_{23} \\
d_{31} & d_{32} & d_{33} \\
d_{41} & d_{42} & d_{43}
\end{pmatrix}
=
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33}
\end{pmatrix}
\delta_{21}
=
\begin{pmatrix}
w_{11}\delta_{21} & w_{12}\delta_{21} & w_{13}\delta_{21} \\
w_{21}\delta_{21} & w_{22}\delta_{21} & w_{23}\delta_{21} \\
w_{31}\delta_{21} & w_{32}\delta_{21} & w_{33}\delta_{21}
\end{pmatrix}
$$
$$
D^{(22)}:\quad
\begin{pmatrix}
d_{22} & d_{23} & d_{24} \\
d_{32} & d_{33} & d_{34} \\
d_{42} & d_{43} & d_{44}
\end{pmatrix}
=
\begin{pmatrix}
w_{11} & w_{12} & w_{13} \\
w_{21} & w_{22} & w_{23} \\
w_{31} & w_{32} & w_{33}
\end{pmatrix}
\delta_{22}
=
\begin{pmatrix}
w_{11}\delta_{22} & w_{12}\delta_{22} & w_{13}\delta_{22} \\
w_{21}\delta_{22} & w_{22}\delta_{22} & w_{23}\delta_{22} \\
w_{31}\delta_{22} & w_{32}\delta_{22} & w_{33}\delta_{22}
\end{pmatrix}
$$
Then add up the contributions of $D^{(11)}, D^{(12)}, D^{(21)}, D^{(22)}$ that fall on the same position $d_{i,j}$, which gives:
$$
\begin{aligned}
d_{11} &= w_{11}\delta_{11} \\
&\;\;\vdots \\
d_{22} &= w_{22}\delta_{11} + w_{21}\delta_{12} + w_{12}\delta_{21} + w_{11}\delta_{22} \\
d_{23} &= w_{23}\delta_{11} + w_{22}\delta_{12} + w_{13}\delta_{21} + w_{12}\delta_{22} \\
d_{32} &= w_{32}\delta_{11} + w_{31}\delta_{12} + w_{22}\delta_{21} + w_{21}\delta_{22} \\
d_{33} &= w_{33}\delta_{11} + w_{32}\delta_{12} + w_{23}\delta_{21} + w_{22}\delta_{22} \\
&\;\;\vdots \\
d_{44} &= w_{33}\delta_{22}
\end{aligned}
$$
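As can be seen, the result agrees with the first formula:
$$
\frac{\partial J(W,b)}{\partial a^l}
=
\begin{pmatrix}
\dfrac{\partial J(W,b)}{\partial a^l_{11}} & \dfrac{\partial J(W,b)}{\partial a^l_{12}} \\
\dfrac{\partial J(W,b)}{\partial a^l_{21}} & \dfrac{\partial J(W,b)}{\partial a^l_{22}}
\end{pmatrix}
=
\begin{pmatrix}
d_{22} & d_{23} \\
d_{32} & d_{33}
\end{pmatrix}
$$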
Note: calculation method two does not require rot180(W).
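Method two can be sketched as a scatter-and-add over the padded gradient followed by a crop. The snippet below is a minimal illustration, assuming NumPy, a single channel and stride 1; the function name `backward_input_scatter` and its `pad` parameter are hypothetical and chosen only for this example.

```python
import numpy as np

def backward_input_scatter(delta, w, pad=1):
    """Add w * delta[i, j] into the window that produced z[i, j], then crop."""
    kh, kw = w.shape
    out_h, out_w = delta.shape
    # d has the size of the padded input (stride 1 assumed).
    d = np.zeros((out_h + kh - 1, out_w + kw - 1))
    for i in range(out_h):
        for j in range(out_w):
            d[i:i + kh, j:j + kw] += w * delta[i, j]   # no rot180 needed
    # Remove the border that corresponds to the zero padding.
    return d[pad:d.shape[0] - pad, pad:d.shape[1] - pad]

rng = np.random.default_rng(2)
w = rng.standard_normal((3, 3))
delta = rng.standard_normal((2, 2))

grad = backward_input_scatter(delta, w)

# The four gradients written out term by term (0-based indices).
expected = np.array([
    [w[1, 1] * delta[0, 0] + w[1, 0] * delta[0, 1] + w[0, 1] * delta[1, 0] + w[0, 0] * delta[1, 1],
     w[1, 2] * delta[0, 0] + w[1, 1] * delta[0, 1] + w[0, 2] * delta[1, 0] + w[0, 1] * delta[1, 1]],
    [w[2, 1] * delta[0, 0] + w[2, 0] * delta[0, 1] + w[1, 1] * delta[1, 0] + w[1, 0] * delta[1, 1],
     w[2, 2] * delta[0, 0] + w[2, 1] * delta[0, 1] + w[1, 2] * delta[1, 0] + w[1, 1] * delta[1, 1]],
])
print(np.allclose(grad, expected))
```

The kernel is used as stored: the rotation in method one is replaced here by the way the overlapping windows accumulate.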
2.1. Code:
import numpy as np

for i in range(H_out):
    for j in range(W_out):
        # window of the padded input seen by output position (i, j)
        x_pad_masked = x_pad[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
        for k in range(F):  # compute dw
            # dw = pad(bottom_data) * top_diff_ij
            dw[k, :, :, :] += np.sum(x_pad_masked * (residual[:, k, i, j])[:, None, None, None], axis=0)
        for n in range(N):  # compute dx_pad
            # dx = (w) * (top_diff_ij)
            dx_pad[n, :, i * stride:i * stride + HH, j * stride:j * stride + WW] += np.sum((self.w[:, :, :, :] * (residual[n, :, i, j])[:, None, None, None]), axis=0)
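The loop above is a fragment of a larger method. One way it might be wrapped into a self-contained function is sketched below. The function name `conv_backward_naive`, the `(N, C, H, W)` and `(F, C, HH, WW)` array layouts, the `stride` and `pad` defaults and the final crop of `dx_pad` are assumptions made for illustration, not taken from the original code.

```python
import numpy as np

def conv_backward_naive(x, w, residual, stride=1, pad=1):
    """Return (dx, dw) given the input x, the kernels w and the residual dJ/dz.

    Assumed shapes: x (N, C, H, W), w (F, C, HH, WW), residual (N, F, H_out, W_out).
    """
    N, C, H_in, W_in = x.shape
    F, _, HH, WW = w.shape
    _, _, H_out, W_out = residual.shape
    x_pad = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    dx_pad = np.zeros_like(x_pad)
    dw = np.zeros_like(w)
    for i in range(H_out):
        for j in range(W_out):
            hs, ws = i * stride, j * stride
            x_pad_masked = x_pad[:, :, hs:hs + HH, ws:ws + WW]
            for k in range(F):
                # each kernel sees every sample's window, weighted by its residual
                dw[k] += np.sum(x_pad_masked * residual[:, k, i, j][:, None, None, None], axis=0)
            for n in range(N):
                # every kernel, scaled by its residual, is added back into the window
                dx_pad[n, :, hs:hs + HH, ws:ws + WW] += np.sum(
                    w * residual[n, :, i, j][:, None, None, None], axis=0)
    # drop the border that corresponds to the zero padding
    dx = dx_pad[:, :, pad:pad + H_in, pad:pad + W_in]
    return dx, dw

# Usage on the 2x2 input / 3x3 kernel example, with one sample, channel and kernel.
rng = np.random.default_rng(3)
a = rng.standard_normal((2, 2))
kernel = rng.standard_normal((3, 3))
delta = rng.standard_normal((2, 2))
dx, dw = conv_backward_naive(a[None, None], kernel[None, None], delta[None, None])
print(dx.shape, dw.shape)
```

With these inputs `dx[0, 0]` should reproduce the four gradients derived earlier, because the `dx_pad` update is the scatter-and-add of method two.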