深度学习中的数学与技巧(4): BatchNormalization 代码实现

BatchNormalization是神经网络中常用的参数初始化的方法。其算法流程图如下: 
这里写图片描述

我们可以把这个流程图以门电路的形式展开,方便进行前向传播和后向传播: 
这里写图片描述

那么前向传播非常简单,直接给出代码:


   
   
  1. def batchnorm_forward(x, gamma, beta, eps):
  2. N, D = x.shape
  3. #为了后向传播求导方便,这里都是分步进行的
  4. #step1: 计算均值
  5. mu = 1./N * np.sum(x, axis = 0)
  6. #step2: 减均值
  7. xmu = x - mu
  8. #step3: 计算方差
  9. sq = xmu ** 2
  10. var = 1./N * np.sum(sq, axis = 0)
  11. #step4: 计算x^的分母项
  12. sqrtvar = np.sqrt(var + eps)
  13. ivar = 1./sqrtvar
  14. #step5: normalization->x^
  15. xhat = xmu * ivar
  16. #step6: scale and shift
  17. gammax = gamma * xhat
  18. out = gammax + beta
  19. #存储中间变量
  20. cache = (xhat,gamma,xmu,ivar,sqrtvar,var,eps)
  21. return out, cache
反向传播则是求导的过程,这里特别要小心,由于门电路中有多个支路,求导时要进行加和。

   
   
  1. def batchnorm_backward(dout, cache):
  2. #解压中间变量
  3. xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
  4. N,D = dout.shape
  5. #step6
  6. dbeta = np.sum(dout, axis=0)
  7. dgammax = dout
  8. dgamma = np.sum(dgammax*xhat, axis=0)
  9. dxhat = dgammax * gamma
  10. #step5
  11. divar = np.sum(dxhat*xmu, axis=0)
  12. dxmu1 = dxhat * ivar #注意这是xmu的一个支路
  13. #step4
  14. dsqrtvar = -1. /(sqrtvar**2) * divar
  15. dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar
  16. #step3
  17. dsq = 1. /N * np.ones((N,D)) * dvar
  18. dxmu2 = 2 * xmu * dsq #注意这是xmu的第二个支路
  19. #step2
  20. dx1 = (dxmu1 + dxmu2) 注意这是x的一个支路
  21. #step1
  22. dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
  23. dx2 = 1. /N * np.ones((N,D)) * dmu 注意这是x的第二个支路
  24. #step0 done!
  25. dx = dx1 + dx2
  26. return dx, dgamma, dbeta



要注意的就是求导时遇到多个支路的情况要进行累加。表达式复杂的话还是分步进行比较不容易出错。

  • 0
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值