BN、CBN、CmBN 的对比与总结
最近看到了关于 Yolo 系列 trick 的总结文章 【Make YOLO Great Again】YOLOv1-v7全系列大解析(Tricks篇),其中提到了 YoloV4 中使用了 CmBN,这是对 CBN 的改进,可以较好的适应小 batch 的情形。论文中给出了一个简要的对比图:
这里结合此图对 BN 和其两种改进策略进行说明。所以需要注意的是,这里存在两个 batch 相关的概念:
- batch:指代与 BN 层的统计量 实际想要相对应的数据池,也就是图片样本数。
- mini-batch:由于整个 batch 独立计算时,受到资源限制可能不现实,于是需要将 batch 拆分成数个 mini-batch,每个 mini-batch 单独计算后汇总得到整个 batch 的统计量。从而归一化特征。
我们日常在分割或者检测中使用 BN 时,此时如果不使用特殊的设定,那么 batch 与 mini-batch 是一样的。CBN 和 CmBN 所做的就是如何使用多个独立的 mini-batch 的数据获得一个近似于更大 batch 的统计量以提升学习效果。
CBN 与 CmBN
CmBN(Cross mini-Batch Normalization)是 CBN(Cross-Iteration Batch Normalization)的修改版。
CBN 主要用来解决在 Batch-Size 较小时,BN 的效果不佳问题。CBN 连续利用多个迭代的数据来变相扩大 batch size 从而改进模型的效果。这种用前几个 iteration 计算好的统计量来计算当前迭代的 BN 统计量的方法会有一个问题:过去的 BN 参数是由过去的网络参数计算出来的特征而得到的,而本轮迭代中计算 BN 时,它们的模型参数其实已经过时了。
假定 batch=4*mini batch,CBN 在 t t t 次迭代:
- 模型基于之前的梯度被更新。此时的 BN 的仿射参数也是最新的。
- 除了本次迭代的统计量,也会使用通过补偿后的前 3 次迭代得到的统计量。这 4 次的统计量会被一起用来得到近似于整个窗口的近似 batch 的 BN 的统计量。
- 使用得到的近似统计量归一化特征。
- 使用当前版本的仿射参数放缩和偏移。
CmBN 是基于 CBN 改进的,按照论文的图示的意思,主要的差异在于从滑动窗口变为固定窗口。每个 batch 中的统计不会使用 batch 之前的迭代的信息,仅会累积该窗口内的 4 次迭代以用于最后一次迭代的更新。这一策略基本与梯度累积策略仍有不同,梯度累加仅仅累加了梯度,但是前面的图中明显可以看到 BN 的统计量实际上也累积了起来,而图 4 中的展现的 BN 似乎更像是梯度累积。
CBN 的实现
# https://github.com/Howal/Cross-iterationBatchNorm/blob/f6d35301789c96e52699a9cbc8d2de8681547770/mmdet/models/utils/CBN.py#L74
def forward(self, input, weight):
# deal with wight and grad of self.pre_dxdw!
self._check_input_dim(input)
y = input.transpose(0, 1)
return_shape = y.shape
y = y.contiguous().view(input.size(1), -1)
# burnin
if self.training and self.burnin > 0:
self.iter_count += 1
self._update_buffer_num()
if self.buffer_num > 0 and self.training and input.requires_grad: # some layers are frozen!
# cal current batch mu and sigma
cur_mu = y.mean(dim=1)
cur_meanx2 = torch.pow(y, 2).mean(dim=1)
cur_sigma2 = y.var(dim=1)
# cal dmu/dw dsigma2/dw
dmudw = torch.autograd.grad(cur_mu, weight, self.ones, retain_graph=True)[0]
dmeanx2dw = torch.autograd.grad(cur_meanx2, weight, self.ones, retain_graph=True)[0]
# update cur_mu and cur_sigma2 with pres
mu_all = torch.stack([cur_mu, ] + [tmp_mu + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_mu, tmp_d, tmp_w in zip(self.pre_mu, self.pre_dmudw, self.pre_weight)])
meanx2_all = torch.stack([cur_meanx2, ] + [tmp_meanx2 + (self.rho * tmp_d * (weight.data - tmp_w)).sum(1).sum(1).sum(1) for tmp_meanx2, tmp_d, tmp_w in zip(self.pre_meanx2, self.pre_dmeanx2dw, self.pre_weight)])
sigma2_all = meanx2_all - torch.pow(mu_all, 2)
# with considering count
re_mu_all = mu_all.clone()
re_meanx2_all = meanx2_all.clone()
re_mu_all[sigma2_all < 0] = 0
re_meanx2_all[sigma2_all < 0] = 0
count = (sigma2_all >= 0).sum(dim=0).float()
mu = re_mu_all.sum(dim=0) / count
sigma2 = re_meanx2_all.sum(dim=0) / count - torch.pow(mu, 2)
self.pre_mu = [cur_mu.detach(), ] + self.pre_mu[:(self.buffer_num - 1)]
self.pre_meanx2 = [cur_meanx2.detach(), ] + self.pre_meanx2[:(self.buffer_num - 1)]
self.pre_dmudw = [dmudw.detach(), ] + self.pre_dmudw[:(self.buffer_num - 1)]
self.pre_dmeanx2dw = [dmeanx2dw.detach(), ] + self.pre_dmeanx2dw[:(self.buffer_num - 1)]
tmp_weight = torch.zeros_like(weight.data)
tmp_weight.copy_(weight.data)
self.pre_weight = [tmp_weight.detach(), ] + self.pre_weight[:(self.buffer_num - 1)]
else:
x = y
mu = x.mean(dim=1)
cur_mu = mu
sigma2 = x.var(dim=1)
cur_sigma2 = sigma2
if not self.training or self.FROZEN:
y = y - self.running_mean.view(-1, 1)
# TODO: outside **0.5?
if self.out_p:
y = y / (self.running_var.view(-1, 1) + self.eps)**.5
else:
y = y / (self.running_var.view(-1, 1)**.5 + self.eps)
else:
if self.track_running_stats is True:
with torch.no_grad():
self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * cur_mu
self.running_var = (1 - self.momentum) * self.running_var + self.momentum * cur_sigma2
y = y - mu.view(-1, 1)
# TODO: outside **0.5?
if self.out_p:
y = y / (sigma2.view(-1, 1) + self.eps)**.5
else:
y = y / (sigma2.view(-1, 1)**.5 + self.eps)
y = self.weight.view(-1, 1) * y + self.bias.view(-1, 1)
return y.view(return_shape).transpose(0, 1)