BN层的均值和标准差的shape是什么样子的?
首先,BN的工作原理是:
# t is the incoming tensor of shape [B, H, W, C]
# mean and stddev are computed along 0 axis and have shape [H, W, C]
mean = mean(t, axis=0)
stddev = stddev(t, axis=0)
for i in 0..B-1:
out[i,:,:,:] = norm(t[i,:,:,:], mean, stddev)
可以看到,均值和方差是[H, W, C]的样子,只在B这个维度上做gather。
但是,Conv层有一个特点,那就是权重共享,卷积核的shape是[h,w,c]的,这并不是全连接的(不是每个像素都有单独的权重),这个卷积核会划过整个图像,因此,图像处理中的BN操作,也就没有理由针对每个像素单独设计,而是也采用和卷积类似的共享参数方法:
# t is still the incoming tensor of shape [B, H, W, C]
# but mean and stddev are computed along (0, 1, 2) axes and have just [C] shape
mean = mean(t, axis=(0, 1, 2))
stddev = stddev(t, axis=(0, 1, 2))
for i in 0..B-1, x in 0..H-1, y in 0..W-1:
out[i,x,y,:] = norm(t[i,x,y,:], mean, stddev)
换句话说,均值和方差的形状实际上是[C]这样的,[B,H,W]三个维度均需要gather。
以上结论来自stack-overflow.