摘要
本文使用纯 Python 和 PyTorch 对比实现 Batch Normalization 函数及其反向传播.
相关原理和详细解释, 请参考文章:
《Batch Normalization 函数详解及反向传播中的梯度求导》
系列文章索引 :
https://blog.csdn.net/oBrightLamp/article/details/85067981
正文
import torch
import numpy as np
class BatchNorm1d:
    """Minimal NumPy re-implementation of the forward pass of
    ``torch.nn.BatchNorm1d`` in training mode (batch statistics only;
    no running mean/var, no momentum).

    ``weight`` (gamma) and ``bias`` (beta) must be assigned by the caller
    before the instance is called. Intermediate results are cached on the
    instance — presumably for a backward pass defined elsewhere in the
    article (``dw``/``db`` suggest so; confirm against the full listing).
    """

    def __init__(self):
        self.eps = 1e-5     # numerical-stability constant added to the variance
        self.weight = None  # gamma: per-feature scale, set externally by the caller
        self.bias = None    # beta: per-feature shift, set externally by the caller
        self.num = None     # batch size of the most recent forward call
        self.sqrt = None    # cached sqrt(var + eps) from the most recent forward call
        self.std = None     # NOTE: despite the name, caches the *normalized* input
                            # (x - mean) / sqrt, not a standard deviation
        self.dw = None      # gradient w.r.t. weight (filled by a backward pass)
        self.db = None      # gradient w.r.t. bias (filled by a backward pass)

    def __call__(self, x):
        """Normalize ``x`` (shape: batch x features) over the batch axis
        and apply the affine transform ``weight * x_hat + bias``.

        Raises TypeError if ``weight``/``bias`` were not assigned first.
        """
        self.num = np.shape(x)[0]
        mean = np.mean(x, axis=0, keepdims=True)
        # Biased (population) variance, matching PyTorch's BatchNorm behavior.
        var = np.var(x, axis=0, keepdims=True)
        self.sqrt = np.sqrt(var + self.eps)
        self.std = (x - mean) / self.sqrt
        out = self.std * self.weight + self.bias
        return out