从零开始实现批量归一化（Batch Normalization）：
from mxnet import ndarray as nd
def pure_batch_norm(X, gemma, beta, eps=1e-5):
    """Batch-normalize ``X`` using the statistics of the current batch only.

    ``X`` must be 2-D or 4-D:

    * 2-D (fully connected input): mean and variance are taken over axis 0,
      i.e. one statistic per column/feature.
    * 4-D (convolutional input, channel on axis 1): mean and variance are
      taken over axes (0, 2, 3), i.e. one statistic per channel.

    Parameters
    ----------
    X : array-like
        Input batch. Any array type providing ``mean(axis=..., keepdims=...)``,
        ``reshape`` and arithmetic operators works (e.g. an MXNet NDArray, as
        in the demo below, or a NumPy array).
    gemma : array-like
        Scale (gamma), one element per feature (2-D) or per channel (4-D).
    beta : array-like
        Shift, one element per feature (2-D) or per channel (4-D).
    eps : float, optional
        Added to the variance before the square root to avoid division by zero.

    Returns
    -------
    array-like
        ``gemma * (X - mean) / sqrt(variance + eps) + beta``, same shape as ``X``.
    """
    assert len(X.shape) in (2, 4), "X must be 2-D or 4-D"
    if len(X.shape) == 2:
        # One statistic per feature (column).
        mean = X.mean(axis=0)
        variance = ((X - mean) ** 2).mean(axis=0)
    else:
        # One statistic per channel; keepdims gives shape (1, C, 1, 1) so the
        # statistics broadcast against X.
        mean = X.mean(axis=(0, 2, 3), keepdims=True)
        variance = ((X - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
    # ``** 0.5`` instead of ``nd.sqrt`` keeps this function independent of the
    # module-level ``nd`` import, so it works for any array backend.
    x_hat = (X - mean) / (variance + eps) ** 0.5
    # Reshape gamma/beta to the statistics' shape so they broadcast the same way.
    return gemma.reshape(mean.shape) * x_hat + beta.reshape(mean.shape)
# Demo 1: fully connected input of shape (3, 2) -> one gamma/beta entry per column.
X = nd.arange(6).reshape((3, 2))
fc_gamma, fc_beta = nd.array([1, 1]), nd.array([0, 0])
y = pure_batch_norm(X, gemma=fc_gamma, beta=fc_beta)
print(y)

# Demo 2: convolutional input of shape (1, 4, 3, 3) -> one gamma/beta entry per channel.
X2 = nd.arange(36).reshape((1, 4, 3, 3))
conv_gamma, conv_beta = nd.array([1, 1, 1, 1]), nd.array([0, 0, 0, 0])
y2 = pure_batch_norm(X2, gemma=conv_gamma, beta=conv_beta)
print(y2)
对于全连接层，输入是二维的（批量大小 × 特征数），需要对每个特征（每一列）分别做归一化，即沿 axis=0 求均值和方差；因此特征有几列（axis=1 方向的长度），缩放参数 gamma（代码中写作 gemma）和平移参数 beta 就各有几个元素。
对于卷积层，输入是四维的（批量大小 × 通道数 × 高 × 宽），需要对每个通道分别做归一化，即沿 axis=(0, 2, 3) 求均值和方差；因此有几个通道（axis=1 方向的长度），gamma（代码中写作 gemma）和 beta 就各有几个元素。
上面的实现总是使用当前批量的统计量。测试（预测）时可能只有很少甚至单个样本，无法可靠地计算批量统计量，因此可以在训练时用移动平均维护全局的均值和方差，测试时直接用它们做归一化：
from mxnet import ndarray as nd
def batch_norm(x, gemma, beta, istraining, moving_mean, moving_variance,
               moving_momentum, eps=1e-5):
    """Batch-normalize ``x``, maintaining running statistics for inference.

    Training mode (``istraining`` truthy):
        ``x`` is normalized with the current batch's mean/variance, and
        ``moving_mean`` / ``moving_variance`` are updated IN PLACE with
        ``moving = moving_momentum * moving + (1 - moving_momentum) * batch``.
    Inference mode:
        ``x`` is normalized with ``moving_mean`` / ``moving_variance``, which
        are left unchanged; no batch statistics are computed.

    Statistics are per column for 2-D ``x`` (reduced over axis 0) and per
    channel for 4-D ``x`` (channel on axis 1, reduced over axes 0, 2, 3).

    Parameters
    ----------
    x : array-like
        2-D or 4-D input batch (e.g. an MXNet NDArray or a NumPy array).
    gemma, beta : array-like
        Scale (gamma) and shift, one element per feature/channel.
    istraining : bool
        Selects training vs. inference behavior (see above).
    moving_mean, moving_variance : array-like
        Running statistics, one element per feature/channel. They are updated
        in place while training.
    moving_momentum : float
        Weight kept on the old running statistic in the moving-average update.
    eps : float, optional
        Added to the variance before the square root to avoid division by zero.

    Returns
    -------
    array-like
        ``gemma * x_hat + beta``, same shape as ``x``.
    """
    assert len(x.shape) in (2, 4), "x must be 2-D or 4-D"
    is_conv = len(x.shape) == 4
    # Shape of the per-feature / per-channel statistics; (1, C, 1, 1) for 4-D
    # input so they broadcast against x.
    stat_shape = (1, x.shape[1], 1, 1) if is_conv else (x.shape[1],)
    # The in-place ``[:] =`` updates below only reach the caller's arrays if
    # ``reshape`` returns a view sharing their memory. The original code relies
    # on the same behavior.
    moving_mean = moving_mean.reshape(stat_shape)
    moving_variance = moving_variance.reshape(stat_shape)
    if istraining:
        if is_conv:
            mean = x.mean(axis=(0, 2, 3), keepdims=True)
            variance = ((x - mean) ** 2).mean(axis=(0, 2, 3), keepdims=True)
        else:
            mean = x.mean(axis=0)
            variance = ((x - mean) ** 2).mean(axis=0)
        x_hat = (x - mean) / (variance + eps) ** 0.5
        moving_mean[:] = moving_momentum * moving_mean + (1.0 - moving_momentum) * mean
        moving_variance[:] = (moving_momentum * moving_variance
                              + (1.0 - moving_momentum) * variance)
    else:
        x_hat = (x - moving_mean) / (moving_variance + eps) ** 0.5
    return x_hat * gemma.reshape(stat_shape) + beta.reshape(stat_shape)
# Zero-initialised running statistics for a 2-feature input, presumably passed
# to batch_norm (which overwrites them in place during training).
moving_mean, moving_variance = nd.zeros(2), nd.zeros(2)
Gluon 实现：