【SyncBN踩坑】DDP训练efficientnet精度异常问题

最近使用DDP(多卡分布式训练)在训练efficientnet时,遇到一个问题:开启DDP和单卡训练efficientnet的精度差异明显:DDP下的训练精度远不如单卡的结果!

尤其当加载了Imagenet的预训练权重后再进行训练,精度差异会更加明显!!

首先需要说明的是,我的efficientnet结构调用自timm库0.6.7

先上调用代码

model = timm.create_model("efficientnet_b0", pretrained=True)

model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

model = DistributedDataParallel(
            model,
            device_ids=[RANK],
            output_device=RANK,
            find_unused_parameters=True,
        )

问题就出在同步BN层这一行!!

class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation

    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):

timm库的efficientnet使用了上述的基于nn.BatchNorm2d 重写的bn方法,把act包了进去!

由于BatchNormAct2d 方法继承自nn.BatchNorm2d,因而还是属于_BatchNorm模块,所以在调用同步BN方法时,会把BatchNormAct2d直接替换成普通的并行BN层!从而丢失了本该存在的act!

class SyncBatchNorm(_BatchNorm):
    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(...)

所以,使用timm库的efficientnet进行DDP训练时,请记得关闭SyncBN方法。或者使用timm自带的SynvBN方法:  timm.models.convert_sync_batchnorm()

同样的情况还有timm库的MobilenetV3,因为它也同样使用了BatchNormAct2d!

  • 2
    点赞
  • 2
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值