Pytorch采坑记录:DDP加载之前的checkpoint后loss上升(metric下降)

  最近在鼓捣使用pytorch的distributeddataparallel这个API搭一个数据并行的训练测试任务,过程中遇到了一个问题,做一下记录。

1、问题

  使用DDP打包了一个模型训练了一段时间,loss不断下降metric不断上升,一切都是很正常的现象。当因为意外暂停或者手动暂停更改学习率而停止了程序,再开启程序加载之前的checkpoint继续训练,却发现loss突然比之前上升或者metric比之前下降了很多。仔细看了一下loss的值,发现直接回到刚开始第一次训练模型时的水平,仿佛checkpoint根本没加载进去,是从初始化开始训练的一样。

2、原因分析

根据我之前的框架使用经验,认为可能的原因有以下两点:

2.1 模型的train和eval模式问题

  由于很多算子在训练模式和测试模式下的前向传播原理不同,例如batchnorm和dropout等,导致几乎所有的框架都会对模型设置一个train或eval的flag。Pytorch可以通过调用model.train()或model.eval()将模型的状态进行切换。在训练模式下如果模型是eval状态或者在推理模式下模型是train状态都会使得结果计算不正确,可能是导致上述问题的一个原因。
  但这个猜想很快就被我给否掉了。第一,按照我的经验,如果一个模型已经训练到一个比较好的状态,即便是搞混了train和eval的状态flag,结果虽然不对但是一般也不会差的特别多。我之前用0.1学习率训练了七八个epoch,损失已经到了0.4~0.5左右。再次训练加载checkpoint损失直接飙到了差不多快到8了,这个跳跃太大了。第二,我去check了一下我的代码,发现并没有出现train和eval搞混的问题(手动狗头)。

2.2 模型没有正确的加载进去

  出现上述问题的另外一个可能的原因是:Pytorch没有正确的将模型加载进去。经常使用pytorch的同学可能都遇到过这样一种情况:自己设计了一个网络用来做某项任务,选择了某个经典分类模型(如resnet等)的特征提取部分作为backbone。训练时在github上下载了已经在ImageNet数据集上pretrain的分类模型,并把这个模型的特征提取部分的权重直接加载到自己的模型中实现backbone预训练。但是效果却并不好,可能的原因之一就是backbone并没有成功的加载进去。
  Pytorch中模型参数的保存底层使用的是字典的结构,因此参数加载需要保证参数名必须是一一对应的。常用的一个加载模型参数的API是load_state_dict,其中有一个参数是strict=True,这个参数用来控制加载模型是否是“严格”的。严格指的是代码模型定义里的所有parameter和buffer必须和要加载的checkpoint里的parameter和buffer的参数名、参数维度、参数类型等能够一一对应上,一个都不能多也不能少,否则就会报错。strict=False则可以允许代码模型定义里的部分parameter或buffer和checkpoint中的对应不上,如果有能对应上的就加载,否则就忽略。比如下面的情况,当strict=False时,parameter2、3、4和5可以被正确加载,parameter1和6不会被加载而采用用户定义的方式初始化;当strict=True时,加载会报错。

  在我遇到的问题中,经过确认我排除了这个可能性。checkpoint是用我自定义的模型训练得到的而不是从网上下载的,模型定义我没有更改过因此和之前的是一样的,而且我设置了strict=True,也没有报错说明模型是被正确加载进去的。

2.3 DistributedDataParallel问题

  以上两种思考没有解决我的问题,此时我痛定思痛仔细回想一下整个过程。同样的代码同样的逻辑之前不做数据并行的时候是没有问题的,但是一做DistributedDataParallel训练就出现了问题,说明bug出在DistributedDataParallel这里。看了一下这个API的源码,找到了问题所在。在这个类的__init__函数里有这么一段:

class DistributedDataParallel(Module):
	def __init__(self, ...):
		...
		# Sync params and buffers
		self._sync_params_and_buffers(authoritative_rank=0)
        ...

也就是说在调用这个API把一个普通的model打包成一个ddp的model后,即实例化一个DistributedDataParallel对象的时候,就已经完成了模型的parameter和buffer在主进程模型和其他进程上replica的同步。而我的代码里,是先实例化了一个ddp对象,然后才去加载checkpoint

...
model = MyModel()
model.to(device=rank)
model = nn.parallel.DistributedDataParallel(model, devices=[rank])
if rank == 0:
	ret = model.load_state_dict(torch.load(xxx), strict=True)
...

此时代码的执行过程是:1、实例化一个MyModel对象并随机初始化;2、实例化一个ddp对象并用之前随机初始化的model去同步其他进程上replica的parameter和buffer;3、将checkpoint的parameter和buffer加载到主进程上的model中。此时其他几个进程上的model的parameter和buffer还都是随机初始化的,在前向和反向传播时虽然主进程上的model给出了类似之前checkpoint比较准确的结果。可是其他几个子进程上的模型由于参数是随机初始化的所以结果差的很远,各个进程上的梯度经过reduce_mean后就错的很离谱了。因此应该调整一下代码的顺序为:

...
model = MyModel()
model.to(device=rank)
if rank == 0:
	ret = model.load_state_dict(torch.load(xxx), strict=True)
model = nn.parallel.DistributedDataParallel(model, devices=[rank])
...

  此时仍然有一个小小的bug,就是通过DistributedDataParallel这个API去打包模型后,模型的所有参数的名字都会多一个module的前缀,还是看一下API的源码:

class DistributedDataParallel(Module):
	def __init__(self, module, ...):
		...
		self.module = module
        ...

熟悉Pytorch.nn.Module这个类的变量命名规则的同学应该知道,加了这个成员变量赋值的语句后,所有模型变量的名字前缀都会多一个module。比如MyModel()实例化的对象中有一个名为conv1.weight的参数,经过DDP打包后得到的新模型中,对应的参数变量名会变为module.conv1.weight,一种解决办法是可以通过保存模型时指定保存DDP对象的module模块来消除这个前缀。

  水平有限,欢迎讨论。

  • 9
    点赞
  • 11
    收藏
    觉得还不错? 一键收藏
  • 7
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 7
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值