[MXNet] Lecture 04: Batch Normalization

Implementation from scratch:


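from mxnet import ndarray as nd

def pure_batch_norm(X,gamma,beta,eps=1e-5):
	# Only 2D (fully connected) and 4D (convolutional) inputs are supported
	assert len(X.shape) in (2,4)
	if len(X.shape)==2:
		# Shape (batch, features): one mean/variance per feature, taken over axis 0
		mean=X.mean(axis=0)
		variance=((X-mean)**2).mean(axis=0)
	else:
		# Shape (batch, channels, height, width): one mean/variance per channel,
		# taken over axes 0, 2, 3; keepdims=True keeps shape (1, C, 1, 1) for broadcasting
		mean=X.mean(axis=(0,2,3),keepdims=True)
		variance=((X-mean)**2).mean(axis=(0,2,3),keepdims=True)

	# Normalize, then scale by gamma and shift by beta (reshaped to broadcast like mean)
	x_hat=(X-mean)/nd.sqrt(variance+eps)
	return gamma.reshape(mean.shape)*x_hat+beta.reshape(mean.shape)

# 2D input: 3 samples with 2 features, so gamma and beta have 2 entries
X=nd.arange(6).reshape((3,2))
y=pure_batch_norm(X,gamma=nd.array([1,1]),beta=nd.array([0,0]))
print(y)
# 4D input: 1 sample with 4 channels of 3x3, so gamma and beta have 4 entries
X2=nd.arange(36).reshape((1,4,3,3))
y2=pure_batch_norm(X2,gamma=nd.array([1,1,1,1]),beta=nd.array([0,0,0,0]))
print(y2)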
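For reference, the computation performed by pure_batch_norm can be written as below, where the sums run over the m values that share one statistic (the m samples in the batch for a fully connected layer; all batch, height and width positions of one channel for a convolutional layer). The variance divides by m, just as the .mean() call does. As a worked check on the 2D example, column 0 of X is (0, 2, 4); ignoring the tiny ε, with γ = 1 and β = 0 the output column is approximately (-1.2247, 0, 1.2247), and column 1, (1, 3, 5), gives the same values.

```latex
\mu = \frac{1}{m}\sum_{i=1}^{m} x_i,\qquad
\sigma^2 = \frac{1}{m}\sum_{i=1}^{m} (x_i-\mu)^2,\qquad
\hat{x}_i = \frac{x_i-\mu}{\sqrt{\sigma^2+\epsilon}},\qquad
y_i = \gamma\,\hat{x}_i + \beta

% Worked example: column 0 of X = (0, 2, 4), m = 3
\mu = 2,\qquad
\sigma^2 = \frac{(0-2)^2+(2-2)^2+(4-2)^2}{3} = \frac{8}{3},\qquad
\hat{x} \approx \left(-\sqrt{3/2},\ 0,\ \sqrt{3/2}\right) \approx (-1.2247,\ 0,\ 1.2247)
```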

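For a fully connected layer, the input is two-dimensional, with shape (batch, features), and each feature is normalized separately over the batch. Gamma (and likewise beta) therefore has one entry per column, i.e. as many entries as the size of axis 1.

For a convolutional layer, the input is four-dimensional, with shape (batch, channels, height, width), and each channel is normalized separately over the batch and all spatial positions. Gamma (and likewise beta) therefore has one entry per channel, i.e. again as many entries as the size of axis 1.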
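To make the axis bookkeeping concrete, here is a minimal pure-Python sketch (no MXNet; the helper name per_channel_stats is hypothetical) that computes per-channel statistics for the same data as nd.arange(36).reshape((1,4,3,3)) by averaging over axes 0, 2 and 3. It yields one mean and one variance per channel, four of each, which is why gamma and beta need four entries for this input.

```python
def per_channel_stats(x):
    """Per-channel mean and (biased) variance of a nested list shaped (N, C, H, W)."""
    num_channels = len(x[0])
    means, variances = [], []
    for ch in range(num_channels):
        # Gather every value of channel ch across the batch, height and width axes (0, 2, 3)
        vals = [v for sample in x for row in sample[ch] for v in row]
        m = sum(vals) / len(vals)
        means.append(m)
        # Divide by len(vals), matching ((X - mean) ** 2).mean(...) in the code above
        variances.append(sum((v - m) ** 2 for v in vals) / len(vals))
    return means, variances

# Same values and layout as nd.arange(36).reshape((1, 4, 3, 3)):
# element (0, ch, r, c) equals ch*9 + r*3 + c
x = [[[[ch * 9 + r * 3 + c for c in range(3)] for r in range(3)] for ch in range(4)]]
means, variances = per_channel_stats(x)
print(means)           # [4.0, 13.0, 22.0, 31.0]
print(len(variances))  # 4
```

Each channel holds nine consecutive integers, so every channel has the same variance (60/9) and only the means differ.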

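At test time there may be no meaningful batch to compute statistics from, so during training we also keep moving averages of the mean and variance and use those instead when testing: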

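from mxnet import ndarray as nd

def batch_norm(x,gamma,beta,is_training,moving_mean,moving_variance,moving_momentum,eps=1e-5):
	assert len(x.shape) in (2,4)
	if len(x.shape)==2:
		# Fully connected: per-feature statistics over the batch axis
		mean=x.mean(axis=0)
		variance=((x-mean)**2).mean(axis=0)
	else:
		# Convolutional: per-channel statistics over batch, height and width
		mean=x.mean(axis=(0,2,3),keepdims=True)
		variance=((x-mean)**2).mean(axis=(0,2,3),keepdims=True)
		# Reshape the running statistics to (1, C, 1, 1) so they broadcast against x
		moving_mean=moving_mean.reshape(mean.shape)
		moving_variance=moving_variance.reshape(variance.shape)

	if is_training:
		# Training: normalize with the current batch statistics...
		x_hat=(x-mean)/nd.sqrt(variance+eps)
		# ...and update the running statistics in place ([:] writes into the existing array)
		moving_mean[:]=moving_momentum*moving_mean+(1.-moving_momentum)*mean
		moving_variance[:]=moving_momentum*moving_variance+(1.-moving_momentum)*variance
	else:
		# Testing: normalize with the accumulated running statistics instead
		x_hat=(x-moving_mean)/nd.sqrt(moving_variance+eps)
	return gamma.reshape(mean.shape)*x_hat+beta.reshape(mean.shape)

# Running statistics for a layer with 2 features. The variance starts at 1
# (as Gluon's BatchNorm does) rather than 0, so early test-time outputs are
# not divided by roughly sqrt(eps).
moving_mean=nd.zeros(2)
moving_variance=nd.ones(2)

X=nd.arange(6).reshape((3,2))
y=batch_norm(X,nd.array([1,1]),nd.array([0,0]),True,moving_mean,moving_variance,0.9)
print(y)
print('moving_mean',moving_mean)
print('moving_variance',moving_variance)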
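The running-statistics update in batch_norm is an exponential moving average: new = momentum * old + (1 - momentum) * batch_stat. The minimal pure-Python sketch below (no MXNet; the helper names ema_update and run, and the constant batch mean of 5.0, are hypothetical) shows how quickly the initial value is forgotten: starting from 0 with momentum 0.9, after k batches the running mean equals 5 * (1 - 0.9**k).

```python
def ema_update(moving, batch_stat, momentum):
    # Same rule as the moving_mean / moving_variance update in batch_norm
    return momentum * moving + (1.0 - momentum) * batch_stat

def run(num_batches, batch_mean=5.0, momentum=0.9, init=0.0):
    # Feed the same (hypothetical) batch statistic num_batches times
    moving = init
    history = []
    for _ in range(num_batches):
        moving = ema_update(moving, batch_mean, momentum)
        history.append(moving)
    return history

history = run(50)
print(round(history[0], 6))   # 0.5
print(round(history[-1], 2))  # 4.97
```

A larger momentum smooths out noisy batch statistics but needs more batches before the running value reflects the data, which is why the running statistics are only trustworthy after a reasonable amount of training.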



Gluon implementation


