FrozenBatchNorm就是"weight" and "bias", "running_mean", "running_var”四个值固定住的BN
经典框架中一直使用的是FrozenBatchNorm2d。如Detectron,DETR, mmdetection?见
"weight" and "bias", "running_mean", "running_var”四个值是buf,通过register_buffer设置不更新。
为什么要使用FrozenBatchNorm
BN层在CNN网络中大量使用,但是BN依赖于均值和方差,如果batch_size太小,计算一个小batch_size的均值和方差,肯定没有计算大的batch_size的均值和方差稳定和有意义,这个时候,还不如不使用bn层,因此可以将bn层冻结。另外,我们使用的网络,几乎都是在imagenet上pre-trained,完全可以使用在imagenet上学习到的参数。
而且,如果使用的是FrozenBatchNorm,多卡训练就不会有BN同步的问题了,那么多卡训练的性能理论上应该和单卡一样好了,注意这点torchvision.ops.FrozenBatchNorm2d(num_features: int, eps: float = 1e-05)
Pytorch FrozenBatchNorm (BN)
最新推荐文章于 2024-04-20 19:38:36 发布