BatchNormalization是神经网络中常用的参数初始化的方法。其算法流程图如下:
我们可以把这个流程图以门电路的形式展开,方便进行前向传播和后向传播:
那么前向传播非常简单,直接给出代码:
def batchnorm_forward(x, gamma, beta, eps):
N, D = x.shape
#为了后向传播求导方便,这里都是分步进行的
#step1: 计算均值
mu = 1./N * np.sum(x, axis = 0)
#step2: 减均值
xmu = x - mu
#step3: 计算方差
sq = xmu ** 2
var = 1./N * np.sum(sq, axis = 0)
#step4: 计算x^的分母项
sqrtvar = np.sqrt(var + eps)
ivar = 1./sqrtvar
#step5: normalization->x^
xhat = xmu * ivar
#step6: scale and shift
gammax = gamma * xhat
out = gammax + beta
#存储中间变量
cache = (xhat,gamma,xmu,ivar,sqrtvar,var,eps)
return out, cache
反向传播则是求导的过程,这里特别要小心,由于门电路中有多个支路,求导时要进行加和。def batchnorm_backward(dout, cache):
#解压中间变量
xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
N,D = dout.shape
#step6
dbeta = np.sum(dout, axis=0)
dgammax = dout
dgamma = np.sum(dgammax*xhat, axis=0)
dxhat = dgammax * gamma
#step5
divar = np.sum(dxhat*xmu, axis=0)
dxmu1 = dxhat * ivar #注意这是xmu的一个支路
#step4
dsqrtvar = -1. /(sqrtvar**2) * divar
dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar
#step3
dsq = 1. /N * np.ones((N,D)) * dvar
dxmu2 = 2 * xmu * dsq #注意这是xmu的第二个支路
#step2
dx1 = (dxmu1 + dxmu2) 注意这是x的一个支路
#step1
dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
dx2 = 1. /N * np.ones((N,D)) * dmu 注意这是x的第二个支路
#step0 done!
dx = dx1 + dx2
return dx, dgamma, dbeta
要注意的就是求导时遇到多个支路的情况要进行累加。表达式复杂的话还是分步进行比较不容易出错。