Expected to have finished reduction in the prior iteration before starting a new one. This error indicates that your module has parameters that were not used in producing its output (the return value of
forward
). You can enable unused parameter detection by passing the keyword argumentfind_unused_parameters=True
totorch.nn.parallel.DistributedDataParallel
. If you already have this argument set, then the distributed data parallel module wasn’t able to locate the output tensors in the return value of your module’s
forward
function. Please include the structure of the return value offorward
of your module when reporting this issue (e.g. list,
dict, iterable)
这个错误的重现:
原始代码:
n_finetune_classes = 40
model = torch.nn.DataParallel(model, device_ids=None)
pretrain = torch.load("pretrain_path")
model.load_state_dict(pretrain['state_dict'], strict=False)
# 由于预训练模型的fc层是原始的,这里我需要finetune,所以修改了fc层
model.module.fc = torch.nn.Linear(model.module.fc.in_features,
n_finetune_classes)
model.module.fc = model.module.fc.cuda()
后来,由于采用了分布式训练的方式于是有以下代码:
n_finetune_classes = 40
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=None)
pretrain = torch.load("pretrain_path")
model.load_state_dict(pretrain['state_dict'], strict=False)
# 由于预训练模型的fc层是原始的,这里我需要finetune,所以修改了fc层
model.module.fc = torch.nn.Linear(model.module.fc.in_features,
n_finetune_classes)
model.module.fc = model.module.fc.cuda()
分析
于是就出现了上述错误,这个错误的字面意思是说,有模型没有累积的参数,也就是说有一些参数没有被加入到模型的参数中。所以可以得到初步结论,使用分布式的框架时候某些参数某有更新,但是这些参数在其他进程中可能会更新,从这个错误来看DistributedDataParallel希望你使用find_unused_parameters=True
来让其他子进程都不更新这个不用更新的参数。
定位
由于没有修改fc之前没有这个错误,所以基本上可以认为是:修改的这个fc层,没有被涵盖进DistributedDataParallel这个容器包裹的module中。于是查阅官方文档找到:
那为什么DataParallel没有这个问题? DP的方式是有个主卡,负责分发,也就是说其他卡是从主卡分发下去的。但是DDP是在你用它包裹model时,会为model的体得注册梯度聚合函数,这个函数就是用来将多卡梯度求和平均得到的。
修复
n_finetune_classes = 40
model.fc = torch.nn.Linear(model.fc.in_features,
n_finetune_classes)
model.fc = model.fc.cuda()
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=None)
pretrain = torch.load("pretrain_path")
model.load_state_dict(pretrain['state_dict'], strict=False)