【Python实现卷积神经网络】:反向传播推导卷积层对输入数据的求导

0.前言

  通过之前的学习【Python实现卷积神经网络】:卷积层的正向传播与反向传播+python实现代码,我们知道卷积层的反向传播有三个梯度要求:

1.对输入数据的求导
这里写图片描述
2.对W的求导
这里写图片描述
3.对b的求导
这里写图片描述


  这篇博客推导第一个公式:对输入数据求导。如下公式是怎么来的:
这里写图片描述

1.【对输入数据求导】计算方法一

  我在之前的博客中举了正向传播输入数据不带pad,它的反向传播对输入数据求导的例子。

  这里我们通过举另外一个输入数据带pad的正向卷积,然后反向传播的例子。

  假设我们现在已经可以递推出上层的梯度误差 δl+1 δ l + 1 了;卷积层输出z,输入a和W,b的关系为:

zl=alWl+b z l = a l ∗ W l + b

  因此本层残差( δl δ l )和上层残差( δl+1 δ l + 1 )的递推关系为:

δl=J(W,b)al=J(W,b)zlzlal=δl+1Wl 伪 码 : δ l = ∂ J ( W , b ) ∂ a l = ∂ J ( W , b ) ∂ z l ∂ z l ∂ a l = δ l + 1 ∗ W l

  上边伪码是为了便于推导理解,事实上公式是:

δl=J(W,b)al=J(W,b)zlzlal=pad(δl+1)rot180(Wl) δ l = ∂ J ( W , b ) ∂ a l = ∂ J ( W , b ) ∂ z l ∂ z l ∂ a l = p a d ( δ l + 1 ) ⊗ r o t 180 ( W l )

  假设我们输入a是2x2的矩阵,且加入pad=1卷积核W是3x3的矩阵,输出z是2x2的矩阵,那么反向传播的z的梯度误差δ也是2x2的矩阵。我们列出a,W,z的矩阵表达式如下:
  

00000a11a2100a12a2200000w11w21w31w12w22w32w13w23w33=(z11z21z12z22) ( 0 0 0 0 0 a 11 a 12 0 0 a 21 a 22 0 0 0 0 0 ) ⊗ ( w 11 w 12 w 13 w 21 w 22 w 23 w 31 w 32 w 33 ) = ( z 11 z 12 z 21 z 22 )

  反向传播的 z z 的梯度误差δl+1是:
  
(δ11δ21δ12δ22) ( δ 11 δ 12 δ 21 δ 22 )

  利用卷积的定义,很容易得出:
z11=0+0+0+0+a11w22+a12w23+0+a21w32+a22w33z12=0+0+0+a11w21+a12w22+0+a21w31+a22w32+0z21=0+a11w12+a12w13+0+a21w22+a22w23+0+0+0z22=a11w11+a12w12+0+a21w21+a22w22+0+0+0+0 z 11 = 0 + 0 + 0 + 0 + a 11 w 22 + a 12 w 23 + 0 + a 21 w 32 + a 22 w 33 z 12 = 0 + 0 + 0 + a 11 w 21 + a 12 w 22 + 0 + a 21 w 31 + a 22 w 32 + 0 z 21 = 0 + a 11 w 12 + a 12 w 13 + 0 + a 21 w 22 + a 22 w 23 + 0 + 0 + 0 z 22 = a 11 w 11 + a 12 w 12 + 0 + a 21 w 21 + a 22 w 22 + 0 + 0 + 0 + 0

  那么根据上面的式子,我们有:

J(W,b)al11=w22δ11+w21δ12+w12δ21+w11δ22J(W,b)al12=w23δ11+w22δ12+w13δ21+w12δ22J(W,b)al21=w32δ11+w31δ12+w22δ21+w21δ22J(W,b)al22=w33δ11+w32δ12+w23δ21+w22δ22 ∂ J ( W , b ) ∂ a 11 l = w 22 δ 11 + w 21 δ 12 + w 12 δ 21 + w 11 δ 22 ∂ J ( W , b ) ∂ a 12 l = w 23 δ 11 + w 22 δ 12 + w 13 δ 21 + w 12 δ 22 ∂ J ( W , b ) ∂ a 21 l = w 32 δ 11 + w 31 δ 12 + w 22 δ 21 + w 21 δ 22 ∂ J ( W , b ) ∂ a 22 l = w 33 δ 11 + w 32 δ 12 + w 23 δ 21 + w 22 δ 22

  最终我们可以一共得到4个式子。整理成矩阵形式后可得:
J(W,b)al=00000δ11δ2100δ12δ2200000w33w23w13w32w22w12w31w21w11 ∂ J ( W , b ) ∂ a l = ( 0 0 0 0 0 δ 11 δ 12 0 0 δ 21 δ 22 0 0 0 0 0 ) ⊗ ( w 33 w 32 w 31 w 23 w 22 w 21 w 13 w 12 w 11 )

  从这个例子证明了刚才的公式的正确性:
  
δl=J(W,b)al=J(W,b)zlzlal=pad(δl+1)rot180(Wl) δ l = ∂ J ( W , b ) ∂ a l = ∂ J ( W , b ) ∂ z l ∂ z l ∂ a l = p a d ( δ l + 1 ) ⊗ r o t 180 ( W l )

 

当然,这个仅仅是对输入数据求导的计算公式1,如果我们有别的计算方法能够得出同样的结果,那么我们也可以总结为对输入数据求导的计算公式2。有没有呢?当然有,稍后再表。

1.1.代码

residual_pad = np.pad(residual, ((0,), (0,), (pad_diff_H,), (pad_diff_W,)), mode='constant', constant_values=0)
for i in range(H_out):
    for j in range(W_out):
        residual_pad_masked = residual_pad[:, :, i*stride:i*stride+HH, j*stride:j*stride+WW]        
        for h in range(C):
            dx_2[:, h , i, j] = np.sum(residual_pad_masked[:,:,:,:] * rot_w[:, h, :, :], axis=(1,2,3))

注意:
这里的pad大小是由正向传播卷积核与正向传播pad共同决定的,不是1。我总结的公式是:

paddiff=kernelsize(1+padfoward) p a d d i f f = k e r n e l s i z e − ( 1 + p a d f o w a r d )

至于这个公式是怎么来的,请读者将上边儿我举的例子中3X3的核变成5X5或者7X7的核,然后推导一边就总结出来了。

2.【对输入数据求导】计算方法二

  还是上边儿的例子,这次我们换种计算方法,看最终结果和上边最终结果一样不。
  
  我们假设:
  

J(W,b)al=pool(d)=pool(d11d21d31d41d12d22d32d42d13d23d33d43d14d24d34d44)=(d22d32d23d33) ∂ J ( W , b ) ∂ a l = p o o l ( d ) = p o o l ( ( d 11 d 12 d 13 d 14 d 21 d 22 d 23 d 24 d 31 d 32 d 33 d 34 d 41 d 42 d 43 d 44 ) ) = ( d 22 d 23 d 32 d 33 )

注:这里pool()池化的意思,在这里表示去掉d的上下左右各pad=1的数字,剩下的部分。
  

d11=d11d21d31d12d22d32d13d23d33=w11w21w31w12w22w32w13w23w33δ11=w11δ11w21δ11w31δ11w12δ11w22δ11w32δ11w13δ11w23δ11w33δ11 d 11 = ( d 11 d 12 d 13 d 21 d 22 d 23 d 31 d 32 d 33 ) = ( w 11 w 12 w 13 w 21 w 22 w 23 w 31 w 32 w 33 ) ∗ δ 11 = ( w 11 ∗ δ 11 w 12 ∗ δ 11 w 13 ∗ δ 11 w 21 ∗ δ 11 w 22 ∗ δ 11 w 23 ∗ δ 11 w 31 ∗ δ 11 w 32 ∗ δ 11 w 33 ∗ δ 11 )


d12=d12d22d32d13d23d33d14d24d34=w11w21w31w12w22w32w13w23w33δ12=w11δ12w21δ12w31δ12w12δ12w22δ12w32δ12w13δ12w23δ12w33δ12 d 12 = ( d 12 d 13 d 14 d 22 d 23 d 24 d 32 d 33 d 34 ) = ( w 11 w 12 w 13 w 21 w 22 w 23 w 31 w 32 w 33 ) ∗ δ 12 = ( w 11 ∗ δ 12 w 12 ∗ δ 12 w 13 ∗ δ 12 w 21 ∗ δ 12 w 22 ∗ δ 12 w 23 ∗ δ 12 w 31 ∗ δ 12 w 32 ∗ δ 12 w 33 ∗ δ 12 )


d21=d21d31d41d22d32d42d23d33d43=w11w21w31w12w22w32w13w23w33δ21=w11δ21w21δ21w31δ21w12δ21w22δ21w32δ21w13δ21w23δ21w33δ21 d 21 = ( d 21 d 22 d 23 d 31 d 32 d 33 d 41 d 42 d 43 ) = ( w 11 w 12 w 13 w 21 w 22 w 23 w 31 w 32 w 33 ) ∗ δ 21 = ( w 11 ∗ δ 21 w 12 ∗ δ 21 w 13 ∗ δ 21 w 21 ∗ δ 21 w 22 ∗ δ 21 w 23 ∗ δ 21 w 31 ∗ δ 21 w 32 ∗ δ 21 w 33 ∗ δ 21 )


d22=d22d32d42d23d33d43d24d34d44=w11w21w31w12w22w32w13w23w33δ22=w11δ22w21δ22w31δ22w12δ22w22δ22w32δ22w13δ22w23δ22w33δ22 d 22 = ( d 22 d 23 d 24 d 32 d 33 d 34 d 42 d 43 d 44 ) = ( w 11 w 12 w 13 w 21 w 22 w 23 w 31 w 32 w 33 ) ∗ δ 22 = ( w 11 ∗ δ 22 w 12 ∗ δ 22 w 13 ∗ δ 22 w 21 ∗ δ 22 w 22 ∗ δ 22 w 23 ∗ δ 22 w 31 ∗ δ 22 w 32 ∗ δ 22 w 33 ∗ δ 22 )


然后,将 d11,d12,d21,d22 d 11 , d 12 , d 21 , d 22 中相应 di,j d i , j 的位置相加,得到:

d11=w11δ11 d 11 = w 11 ∗ δ 11

... . . .

d22=w22δ11+w21δ12+w12δ21+w11δ22 d 22 = w 22 ∗ δ 11 + w 21 ∗ δ 12 + w 12 ∗ δ 21 + w 11 ∗ δ 22

d23=w23δ11+w22δ12+w13δ21+w12δ22 d 23 = w 23 ∗ δ 11 + w 22 ∗ δ 12 + w 13 ∗ δ 21 + w 12 ∗ δ 22

d32=w32δ11+w31δ12+w22δ21+w21δ22 d 32 = w 32 ∗ δ 11 + w 31 ∗ δ 12 + w 22 ∗ δ 21 + w 21 ∗ δ 22

d23=w33δ11+w32δ12+w23δ21+w22δ22 d 23 = w 33 ∗ δ 11 + w 32 ∗ δ 12 + w 23 ∗ δ 21 + w 22 ∗ δ 22

... . . .

d44=w33δ22 d 44 = w 33 ∗ δ 22

可以看出,我们的计算结果与第一个公式一样:

J(W,b)al=J(W,b)al11J(W,b)al21J(W,b)al12J(W,b)al22=(d22d32d23d33) ∂ J ( W , b ) ∂ a l = ( ∂ J ( W , b ) ∂ a 11 l ∂ J ( W , b ) ∂ a 12 l ∂ J ( W , b ) ∂ a 21 l ∂ J ( W , b ) ∂ a 22 l ) = ( d 22 d 23 d 32 d 33 )

注意:计算方法二不需要rot180(w)

2.1.代码:

      for i in range(H_out):
            for j in range(W_out):
                x_pad_masked = x_pad[:, :, i * stride:i * stride + HH, j * stride:j * stride + WW]
                for k in range(F):  # compute dw
                    dw[k, :, :, :] += np.sum(x_pad_masked * (residual[:, k, i, j])[:, None, None, None], axis=0)  
                    # dw=pad(bottom_data)* top_diff_ij
                for n in range(N):  # compute dx_pad
                    dx_pad[n, :, i * stride:i * stride + HH, j * stride:j * stride + WW] += np.sum((self.w[:, :, :, :] * (residual[n, :, i,j])[:, None, None, None]), axis=0)
                    # dx = (w)* (top_diff_ij)
  • 0
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 0
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值