关于BN的具体定义可以看看论文:Batch Normalization:Accelerating Deep Network Training by Reducing Internal Covariate Shift
BN在训练阶段是保持计算每一个Batch的均值和方差的,其计算算法:
但是我们知道在训练过程中有很多的batch,怎么保证在训练和测试阶段的batch参数保持一致呢?
其实就是BN在训练阶段每一个batch按照如上图的算法流程计算每个batch的均值和方差,然后通过一个滑动平均值方法保存,在pytorch中是通过一个参数momentum保存每个连续计算的batch的均值和方差的
x' = (1-momentum)*x + momentum*x''
其中x'表示新的保存下来的值,x是之前旧的保存的值,x''是表示新的batch计算的当前均值方差等。
在BN论文中讲的测试阶段使用的参数是:
其中的均值是每个batch的均值的均值,方差是每个batch的无偏估计量。但是在pytorch具体实现是采用以上所说的滑动平均值方法计算的,所以最后一旦整个训练阶段完成,BN层中的所有参数也就固定下来,然后直接用于test。
在pytorch中由于BN具体实现的代码在cpp中,暂时没找到具体的实现在哪,以上只是由python接口api和doc猜测而来。大家有知道具体在哪实现可以留言。