I recently hit a problem while training EfficientNet with DDP (multi-GPU distributed training): with DDP enabled, the training accuracy was clearly worse than the single-GPU result!
The gap becomes even more pronounced when training starts from ImageNet pretrained weights!!
First, some context: my EfficientNet comes from the timm library, version 0.6.7.
Here is the setup code:
import timm
import torch
from torch.nn.parallel import DistributedDataParallel

model = timm.create_model("efficientnet_b0", pretrained=True)
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
model = DistributedDataParallel(
    model,
    device_ids=[RANK],       # RANK: this process's local GPU index
    output_device=RANK,
    find_unused_parameters=True,
)
The problem is in the SyncBN conversion line!!
class BatchNormAct2d(nn.BatchNorm2d):
    """BatchNorm + Activation
    This module performs BatchNorm + Activation in a manner that will remain backwards
    compatible with weights trained with separate bn, act. This is why we inherit from BN
    instead of composing it as a .bn member.
    """
    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, track_running_stats=True,
                 apply_act=True, act_layer=nn.ReLU, inplace=True, drop_block=None):
        ...  # rest of the class omitted
timm's EfficientNet uses this BatchNormAct2d layer, which is rewritten on top of nn.BatchNorm2d and folds the activation into the BN module.
Because BatchNormAct2d inherits from nn.BatchNorm2d, it is still recognized as a _BatchNorm module. When the SyncBN conversion walks the model, every BatchNormAct2d is therefore replaced with a plain SyncBatchNorm layer, and the activation that should follow the BN is silently dropped! You can see this in PyTorch's own converter:
class SyncBatchNorm(_BatchNorm):
    @classmethod
    def convert_sync_batchnorm(cls, module, process_group=None):
        module_output = module
        if isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
            module_output = torch.nn.SyncBatchNorm(...)
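To see the replacement happen in practice, here is a minimal check. This is a sketch under the assumption that timm 0.6.x exposes BatchNormAct2d via timm.models.layers; the conversion itself does not require torch.distributed to be initialized.

import timm
import torch
from timm.models.layers import BatchNormAct2d

model = timm.create_model("efficientnet_b0", pretrained=False)
n_before = sum(isinstance(m, BatchNormAct2d) for m in model.modules())

# convert_sync_batchnorm only rebuilds modules; no process group is needed here
converted = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
n_after = sum(isinstance(m, BatchNormAct2d) for m in converted.modules())
n_sync = sum(isinstance(m, torch.nn.SyncBatchNorm) for m in converted.modules())

print(f"BatchNormAct2d: {n_before} before, {n_after} after; SyncBatchNorm: {n_sync}")
# Expected: n_before > 0 and n_after == 0, i.e. every BatchNormAct2d (BN plus its
# fused activation) has been swapped for a plain SyncBatchNorm, so the activation is gone.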
So, when training timm's EfficientNet with DDP, remember to skip the SyncBN conversion, or use timm's own SyncBN converter instead: timm.models.convert_sync_batchnorm().
The same issue affects timm's MobileNetV3, since it also uses BatchNormAct2d! A corrected setup is sketched below.
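Putting it together, a corrected DDP setup would look roughly like this. It is a sketch, not a canonical recipe: it assumes a torchrun-style launcher that sets LOCAL_RANK, and it uses timm.models.convert_sync_batchnorm as named above (the exact import path can differ between timm versions).

import os

import timm
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel

dist.init_process_group(backend="nccl")
RANK = int(os.environ["LOCAL_RANK"])  # local GPU index set by torchrun (assumption)
torch.cuda.set_device(RANK)

model = timm.create_model("efficientnet_b0", pretrained=True).cuda()

# Use timm's converter instead of torch.nn.SyncBatchNorm.convert_sync_batchnorm:
# it is aware of BatchNormAct2d and keeps the fused activation.
model = timm.models.convert_sync_batchnorm(model)

model = DistributedDataParallel(
    model,
    device_ids=[RANK],
    output_device=RANK,
    find_unused_parameters=True,
)

Alternatively, simply drop the conversion line and train with per-GPU BN statistics; that also avoids the accuracy drop described above.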