深度学习中的数学与技巧(4): BatchNormalization 代码实现

BatchNormalization是神经网络中常用的参数初始化的方法。其算法流程图如下: 
这里写图片描述

我们可以把这个流程图以门电路的形式展开,方便进行前向传播和后向传播: 
这里写图片描述

那么前向传播非常简单,直接给出代码:

def batchnorm_forward(x, gamma, beta, eps):

  N, D = x.shape
  #为了后向传播求导方便,这里都是分步进行的
  #step1: 计算均值
  mu = 1./N * np.sum(x, axis = 0)

  #step2: 减均值
  xmu = x - mu

  #step3: 计算方差
  sq = xmu ** 2
  var = 1./N * np.sum(sq, axis = 0)

  #step4: 计算x^的分母项
  sqrtvar = np.sqrt(var + eps)
  ivar = 1./sqrtvar

  #step5: normalization->x^
  xhat = xmu * ivar

  #step6: scale and shift
  gammax = gamma * xhat
  out = gammax + beta

  #存储中间变量
  cache =  (xhat,gamma,xmu,ivar,sqrtvar,var,eps)

  return out, cache
反向传播则是求导的过程,这里特别要小心,由于门电路中有多个支路,求导时要进行加和。
def batchnorm_backward(dout, cache):

  #解压中间变量
  xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache

  N,D = dout.shape

  #step6
  dbeta = np.sum(dout, axis=0)
  dgammax = dout
  dgamma = np.sum(dgammax*xhat, axis=0)
  dxhat = dgammax * gamma

  #step5
  divar = np.sum(dxhat*xmu, axis=0)
  dxmu1 = dxhat * ivar #注意这是xmu的一个支路

  #step4
  dsqrtvar = -1. /(sqrtvar**2) * divar
  dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar

  #step3
  dsq = 1. /N * np.ones((N,D)) * dvar
  dxmu2 = 2 * xmu * dsq #注意这是xmu的第二个支路

  #step2
  dx1 = (dxmu1 + dxmu2) 注意这是x的一个支路


  #step1
  dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
  dx2 = 1. /N * np.ones((N,D)) * dmu 注意这是x的第二个支路

  #step0 done!
  dx = dx1 + dx2

  return dx, dgamma, dbeta



要注意的就是求导时遇到多个支路的情况要进行累加。表达式复杂的话还是分步进行比较不容易出错。

  • 0
    点赞
  • 5
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值