# Batch normalization layer implemented with NumPy.
import numpy as np
class MyBatchNorm2d:
    """Batch normalization over the channel axis of 4-D ``(N, C, H, W)`` arrays.

    Each channel is normalized with statistics pooled over the batch and
    spatial axes, then scaled by ``gamma`` and shifted by ``beta`` (both of
    shape ``(C,)``).

    Running statistics are updated as
    ``running = momentum * running + (1 - momentum) * batch_stat``, so a
    ``momentum`` close to 1 makes the running estimates change slowly.

    Args:
        num_features: Number of channels ``C``.
        eps: Constant added to the variance for numerical stability.
        momentum: Weight kept on the previous running statistic per update.
    """

    # Axes pooled into each channel's statistics: batch, height, width.
    _REDUCE_AXES = (0, 2, 3)

    def __init__(self, num_features, eps=0.001, momentum=0.99):
        self.num_features = num_features
        self.gamma = np.ones([num_features, ])
        self.beta = np.zeros([num_features, ])
        self.eps = eps
        self.momentum = momentum
        # Per-channel running estimates consumed by forward_test.
        self.running_mean = np.zeros([num_features, ])
        self.running_var = np.ones([num_features, ])

    @staticmethod
    def _as_channel(v):
        """Reshape a per-channel vector to ``(1, C, 1, 1)`` to broadcast over NCHW."""
        return np.reshape(v, (1, -1, 1, 1))

    def _check_input(self, input):
        """Return ``input`` as an ndarray after validating rank and channel count.

        Raises:
            ValueError: if the input is not 4-D or its axis 1 is not
                ``num_features`` long.
        """
        x = np.asarray(input)
        if x.ndim != 4:
            raise ValueError(
                f"expected a 4-D (N, C, H, W) input, got {x.ndim}-D"
            )
        # Without this check a 1-channel input would silently broadcast
        # against the C-channel parameters.
        if x.shape[1] != self.num_features:
            raise ValueError(
                f"expected {self.num_features} channels, got {x.shape[1]}"
            )
        return x

    def forward_train(self, input):
        """Normalize with the current batch's statistics and update running stats.

        Args:
            input: Array-like of shape ``(N, C, H, W)``.

        Returns:
            ndarray of the same shape as ``input``.

        Raises:
            ValueError: on wrong rank or channel count.
        """
        x = self._check_input(input)
        # Biased (population) per-channel statistics, shape (C,).
        mean = x.mean(axis=self._REDUCE_AXES)
        var = x.var(axis=self._REDUCE_AXES)
        input_norm = (x - self._as_channel(mean)) / np.sqrt(
            self.eps + self._as_channel(var)
        )
        results = (
            self._as_channel(self.gamma) * input_norm
            + self._as_channel(self.beta)
        )
        self.running_mean = (
            self.running_mean * self.momentum + (1 - self.momentum) * mean
        )
        self.running_var = (
            self.running_var * self.momentum + (1 - self.momentum) * var
        )
        return results

    def forward_test(self, input):
        """Normalize with the stored running statistics (no state is modified).

        Args:
            input: Array-like of shape ``(N, C, H, W)``.

        Returns:
            ndarray of the same shape as ``input``.

        Raises:
            ValueError: on wrong rank or channel count.
        """
        x = self._check_input(input)
        input_norm = (x - self._as_channel(self.running_mean)) / np.sqrt(
            self.eps + self._as_channel(self.running_var)
        )
        return (
            input_norm * self._as_channel(self.gamma)
            + self._as_channel(self.beta)
        )