Pytorch如何加载多卡预训练模型以及如何解决此次训练所用卡数与之前卡数不同的问题?
当我们使用Pytorch多卡训练得到一个模型然后再使用单卡加载时可能就会报错。或者多卡加载单卡训练的模型也有可能出错,这主要是因为训练的模型记录了相关的信息,那么如何解决如何问题呢,只需要按照如下方式进行预训练模型的加载即可。
# If specified we start from checkpoint
if opt.pretrained_weights:
'''
pretrained_dict = torch.load(opt.pretrained_weights)
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k.startswith('backbone')}
model_dict = model.state_dict()
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
#'''
# all weights
#'''
pretrained_dict = torch.load(opt.pretrained_weights)
pretrained_dict = {
k.replace('module.',''): v for k, v in
pretrained_dict.items()
}
model_dict = model.state_dict()
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)
#'''