原代码：
def load(self, model_path=None):
    """Merge the tensors stored at ``model_path`` into ``self.model``.

    The model's current state dict is taken as the base and every entry
    from the checkpoint overwrites it. The merged dict is then loaded
    with PyTorch's default strict key matching, so a checkpoint key the
    model does not have (e.g. a ``module.``-prefixed key from a
    DataParallel-wrapped model) makes ``load_state_dict`` raise
    ``RuntimeError``.

    Args:
        model_path: Checkpoint location passed to ``torch.load``. When
            falsy, the method does nothing.
    """
    # Guard clause: nothing to load.
    if not model_path:
        return
    self.logger.info('load_model_path: ' + model_path)
    # map_location places the loaded tensors on this object's device.
    checkpoint = torch.load(model_path, map_location=self.device)
    merged = self.model.state_dict()
    # Overlay checkpoint entries on top of the model's own entries.
    merged.update(checkpoint.items())
    self.model.load_state_dict(merged)
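改为：调用 `load_state_dict` 时传入 `strict=False`。注意：`strict=False` 只是让不匹配的键不再报错，而不会去匹配它们；如果报错原因是 DataParallel 保存的键带有 `module.` 前缀（或当前模型带前缀而权重文件不带），这些权重会被静默跳过、并未真正加载，此时应先统一键名前缀。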
def load(self, model_path=None):
    """Load a checkpoint's state dict into ``self.model``.

    Checkpoint keys are reconciled with the model's own keys before
    loading. A ``module.`` prefix (added when a model is wrapped in
    ``torch.nn.DataParallel``) is stripped or added as needed, so a
    checkpoint saved with or without the wrapper loads into either form.
    Checkpoint keys that still have no counterpart in the model are
    logged and skipped instead of raising. Model entries absent from the
    checkpoint keep their current values.

    Args:
        model_path: Checkpoint location passed to ``torch.load``. When
            falsy (``None``, ``''``), nothing is loaded.
    """
    if not model_path:
        return
    # Lazy %-formatting also works when model_path is a pathlib.Path.
    self.logger.info('load_model_path: %s', model_path)

    saved_state = torch.load(model_path, map_location=self.device)
    model_state = self.model.state_dict()

    prefix = 'module.'
    matched = {}
    ignored = []
    for key, value in saved_state.items():
        if key not in model_state:
            # Reconcile the DataParallel wrapper prefix in either direction;
            # without this, strict=False would silently drop these weights.
            if key.startswith(prefix) and key[len(prefix):] in model_state:
                key = key[len(prefix):]
            elif prefix + key in model_state:
                key = prefix + key
        if key in model_state:
            matched[key] = value
        else:
            ignored.append(key)

    if ignored:
        self.logger.warning(
            'ignored %d checkpoint key(s) not present in the model: %s',
            len(ignored), ignored)
    if saved_state and not matched:
        self.logger.warning(
            'no checkpoint keys matched the model; weights left unchanged')

    model_state.update(matched)
    # model_state starts from the model's own keys and `matched` adds none
    # new, so no key mismatch remains; strict=False is kept from the
    # original fix. Tensor shape mismatches still raise.
    self.model.load_state_dict(model_state, strict=False)
参考：《PyTorch 加载模型报错 RuntimeError: Error(s) in loading state_dict for DataParallel》，qq_29631521 的博客，CSDN。