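running_mean is the per-channel mean of the pixel values: for an input of shape (batch_size, C_out, H, W), the mean for each channel is taken over the batch_size, H and W dimensions, so running_mean has shape (C_out,).
The official PyTorch documentation for BatchNorm2d is shown above. momentum is usually left at its default of 0.1, so: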
running_mean = 0.1 * mean + 0.9 * running_mean
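The update rule is an exponential moving average. The sketch below applies it by hand in plain Python. The helper name update_running_mean and the constant batch means [1.0, 2.0] are hypothetical values chosen for illustration, not part of PyTorch.

```python
def update_running_mean(running_mean, batch_mean, momentum=0.1):
    # Hypothetical helper: new = momentum * batch_mean + (1 - momentum) * old, per channel
    return [momentum * m + (1.0 - momentum) * r for r, m in zip(running_mean, batch_mean)]

running = [0.0, 0.0]  # starts at zero, like PyTorch's initial running_mean
for batch_mean in ([1.0, 2.0], [1.0, 2.0], [1.0, 2.0]):
    running = update_running_mean(running, batch_mean)
    print(running)
```

With a constant batch mean m, after k batches the value is (1 - 0.9**k) * m, so it only approaches the true mean gradually (0.271 * m after three batches).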
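In other words, the new value is 0.1 times the current batch's mean for that channel plus 0.9 times the previous running_mean. For batch 0, the previous running_mean is initialized to 0.0, so after batch 0, running_mean = 0.1 * mean.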
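A quick check of the first-batch behavior against torch.nn.BatchNorm2d itself. The channel count, tensor sizes, random seed and the +5.0 offset are arbitrary assumptions for illustration; the point is that one training-mode forward pass from the zero-initialized buffer leaves running_mean equal to 0.1 times the per-channel batch mean.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C_out = 3
bn = nn.BatchNorm2d(C_out, momentum=0.1)  # momentum=0.1 is also the default
print(bn.running_mean)  # all zeros before any training step

x = torch.randn(8, C_out, 4, 4) + 5.0     # (batch_size, C_out, H, W)
batch_mean = x.mean(dim=(0, 2, 3))         # per-channel mean over N, H, W -> shape (C_out,)

bn.train()
bn(x)  # a training-mode forward pass updates the running statistics in place

expected = 0.1 * batch_mean + 0.9 * torch.zeros(C_out)
print(bn.running_mean.shape, torch.allclose(bn.running_mean, expected, atol=1e-5))
```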
In eval mode, the running_mean obtained during training is used directly instead of the current batch's mean.
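The sketch below checks this with torch.nn.BatchNorm2d: after a few training steps on random data (sizes, seed and the scale/offset are illustrative assumptions), the eval-mode output matches a manual normalization using the stored running_mean. Note that eval mode also uses the stored running_var, together with the layer's eps, weight and bias; and an eval forward pass does not change running_mean.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for _ in range(5):
    bn(torch.randn(8, 3, 4, 4) * 2.0 + 1.0)  # training-mode steps update the running stats

bn.eval()
before = bn.running_mean.clone()
x = torch.randn(2, 3, 4, 4)
with torch.no_grad():
    y = bn(x)
    # Manual eval-mode normalization: per-channel stats broadcast over N, H, W
    mean = bn.running_mean.view(1, -1, 1, 1)
    var = bn.running_var.view(1, -1, 1, 1)
    w = bn.weight.view(1, -1, 1, 1)
    b = bn.bias.view(1, -1, 1, 1)
    y_manual = (x - mean) / torch.sqrt(var + bn.eps) * w + b

print(torch.allclose(y, y_manual, atol=1e-5), torch.equal(before, bn.running_mean))
```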