I was recently reading the PyTorch source code of YOLOv4 and had some questions about the snippet below, which involves running_mean and running_var:
```python
def load_conv_bn(buf, start, conv_model, bn_model):
    """Load one conv + BN block from a flat Darknet weight buffer.

    The values are stored in the order: BN bias (beta), BN weight (gamma),
    running_mean, running_var, then the conv weights.
    """
    num_w = conv_model.weight.numel()
    num_b = bn_model.bias.numel()
    bn_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.running_mean.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.running_var.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    conv_model.weight.data.copy_(
        torch.from_numpy(buf[start:start + num_w]).reshape(conv_model.weight.data.shape))
    start = start + num_w
    return start
```
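To check the offset bookkeeping, here is a minimal sketch of calling the function on a fake buffer. The channel sizes and the `np.arange` buffer are my own illustration, not from the YOLOv4 repo; the function is repeated so the snippet runs standalone.

```python
import numpy as np
import torch
import torch.nn as nn

def load_conv_bn(buf, start, conv_model, bn_model):
    # Same logic as the function above, repeated so this snippet is self-contained.
    num_w = conv_model.weight.numel()
    num_b = bn_model.bias.numel()
    bn_model.bias.data.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.weight.data.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.running_mean.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    bn_model.running_var.copy_(torch.from_numpy(buf[start:start + num_b]))
    start = start + num_b
    conv_model.weight.data.copy_(
        torch.from_numpy(buf[start:start + num_w]).reshape(conv_model.weight.shape))
    return start + num_w

# Dummy conv/BN pair; the shapes are arbitrary for illustration.
conv = nn.Conv2d(3, 16, kernel_size=3, bias=False)
bn = nn.BatchNorm2d(16)
num_w = conv.weight.numel()  # 16 * 3 * 3 * 3 = 432
num_b = bn.bias.numel()      # 16
# Fake flat buffer: four BN vectors followed by the conv kernel.
buf = np.arange(4 * num_b + num_w, dtype=np.float32)

end = load_conv_bn(buf, 0, conv, bn)
print(end)                 # 4 * num_b + num_w: the whole buffer was consumed
print(bn.running_mean[:3]) # third chunk of the buffer: 32., 33., 34.
```

Note that `running_mean` and `running_var` are copied into directly (no `.data`), which works because they are buffers, not parameters, and carry no gradient.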
- When PyTorch prints a network's parameters, a BN layer shows only `weight` and `bias`. Yet a BN layer actually carries four per-channel tensors: in PyTorch only learnable tensors are called parameters, while `running_mean` and `running_var` are buffers. During the forward pass, the batch mean and variance are first computed from the input X, and `running_mean`/`running_var` are then updated from them using the momentum.
- A BN layer also stores the update momentum `momentum` and a small `eps` that guards against numerical errors (division by a near-zero variance).
- During training, `running_mean` and `running_var` are updated once per forward pass. At test time, calling `net.eval()` freezes them: they keep the values set by the last training forward pass and stay fixed throughout evaluation.
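The parameter/buffer split and the train/eval behavior described above can be verified directly on a standard `nn.BatchNorm2d`:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(4)
# Only the learnable gamma/beta appear as parameters...
print([name for name, _ in bn.named_parameters()])  # ['weight', 'bias']
# ...while the running statistics are registered as buffers.
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']

bn.train()
bn(torch.randn(8, 4, 2, 2))   # train mode: running stats are updated
frozen = bn.running_mean.clone()

bn.eval()
bn(torch.randn(8, 4, 2, 2))   # eval mode: running stats stay fixed
print(torch.equal(bn.running_mean, frozen))  # True
```

Because buffers are part of the module's `state_dict`, `running_mean` and `running_var` are still saved and loaded with checkpoints even though they are not parameters.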
```python
def batch_norm(self, x):
    """
    :param x: input data, shape (batch, features)
    :return: BN output
    """
    x_mean = x.mean(axis=0)
    x_var = x.var(axis=0)
    # Update formulas for the running statistics:
    # running_mean = (1 - momentum) * mean_old + momentum * mean_new
    # running_var  = (1 - momentum) * var_old  + momentum * var_new
    self._running_mean = (1 - self._momentum) * self._running_mean + self._momentum * x_mean
    self._running_var = (1 - self._momentum) * self._running_var + self._momentum * x_var
    # The normalization formula from the BN paper
    x_hat = (x - x_mean) / np.sqrt(x_var + self._eps)
    y = self._gamma * x_hat + self._beta
    return y
```
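This manual implementation can be cross-checked against `nn.BatchNorm1d` in training mode. The wrapper class below is my own scaffolding around the method above, with PyTorch's defaults (`momentum=0.1`, `eps=1e-5`). The normalized outputs should agree, as should `running_mean`; `running_var` will differ slightly, because PyTorch updates it with the unbiased batch variance while NumPy's `x.var(axis=0)` is biased:

```python
import numpy as np
import torch
import torch.nn as nn

class ManualBN:
    """Hypothetical wrapper around the batch_norm method above."""
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self._momentum = momentum
        self._eps = eps
        self._gamma = np.ones(num_features, dtype=np.float32)
        self._beta = np.zeros(num_features, dtype=np.float32)
        self._running_mean = np.zeros(num_features, dtype=np.float32)
        self._running_var = np.ones(num_features, dtype=np.float32)

    def batch_norm(self, x):
        x_mean = x.mean(axis=0)
        x_var = x.var(axis=0)  # biased variance, as used for normalization
        self._running_mean = (1 - self._momentum) * self._running_mean + self._momentum * x_mean
        self._running_var = (1 - self._momentum) * self._running_var + self._momentum * x_var
        x_hat = (x - x_mean) / np.sqrt(x_var + self._eps)
        return self._gamma * x_hat + self._beta

x = np.random.randn(32, 8).astype(np.float32)
manual = ManualBN(8)
y_manual = manual.batch_norm(x)

bn = nn.BatchNorm1d(8)
bn.train()  # training mode: normalize with batch statistics
y_torch = bn(torch.from_numpy(x)).detach().numpy()

print(np.abs(y_manual - y_torch).max())  # small, on the order of float32 precision
```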
I wrote down some earlier thoughts on BN in Batch Normalization, which may be worth reading alongside this post.