Batch Normalization的反向传播详细解说
Batch Normalization在我的另一篇博客中已经详细说明了,而这篇我将详细介绍下Batch Normalization的反向传播的细节。
先贴张前向和反向传播图:
从左到右,沿着黑色箭头向前传播。输入是一个矩阵X, γ \gamma γ和 β \beta β作为向量。从右到左,沿着红色箭头反向传播,将梯度从上一层分布到 γ \gamma γ和 β \beta β,并一直返回到输入。
篇外话,下图中简单说明了正向传播和反向传播的示意图,如果看不懂,建议百度搜索链式传导进行基础学习。
下面列出第一张图中的反向传播的每块详细内容,正因为是链式反向传播,所以得从后往前说。这里只贴图,不进行详细说明了(因为能看得懂)
步骤9:
步骤8:
步骤7:
步骤6:
步骤5:
步骤4:
步骤3:
步骤2:
步骤1:
步骤0:
代码:
def batchnorm_backward(dout, cache):
#unfold the variables stored in cache
xhat,gamma,xmu,ivar,sqrtvar,var,eps = cache
#get the dimensions of the input/output
N,D = dout.shape
#step9
dbeta = np.sum(dout, axis=0)
dgammax = dout #not necessary, but more understandable
#step8
dgamma = np.sum(dgammax*xhat, axis=0)
dxhat = dgammax * gamma
#step7
divar = np.sum(dxhat*xmu, axis=0)
dxmu1 = dxhat * ivar
#step6
dsqrtvar = -1. /(sqrtvar**2) * divar
#step5
dvar = 0.5 * 1. /np.sqrt(var+eps) * dsqrtvar
#step4
dsq = 1. /N * np.ones((N,D)) * dvar
#step3
dxmu2 = 2 * xmu * dsq
#step2
dx1 = (dxmu1 + dxmu2)
dmu = -1 * np.sum(dxmu1+dxmu2, axis=0)
#step1
dx2 = 1. /N * np.ones((N,D)) * dmu
#step0
dx = dx1 + dx2
return dx, dgamma, dbeta