RuntimeError: Error(s) in loading state_dict for DataParallel:
Missing key(s) in state_dict: "module.model.Frequency_Edge_Module1.UAM.bn.weight", "module.model.Frequency_Edge_Module1.UAM.bn.bias", "module.model.Frequency_Edge_Module1.UAM.bn.running_mean", "module.model.Frequency_Edge_Module1.UAM.bn.running_var",
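When a model is trained on multiple GPUs with nn.DataParallel, every key in the saved state_dict gets a "module." prefix (for example "module.model.xxx" instead of "model.xxx"). Loading such a checkpoint into a model that is wrapped differently from the one that was saved fails with missing/unexpected key errors like the one above.
Strip the prefix before loading: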
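import torch

# Multi-GPU (DataParallel) training prefixes every key with 'module.'
state_dict = torch.load(model_path, map_location=torch.device('cpu'))
new_state = {}
for key, value in state_dict.items():
    # Strip only the leading prefix; key.replace('module.', '') would also remove 'module.' found elsewhere in a key
    new_key = key[len('module.'):] if key.startswith('module.') else key
    new_state[new_key] = value
self.model.load_state_dict(new_state)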
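The same fix can be wrapped in a small reusable helper. This is a minimal, torch-free sketch: the names strip_prefix and add_prefix are hypothetical (not part of PyTorch), and the helpers only rename dictionary keys, so they work on whatever torch.load returns as long as it is a flat state_dict. strip_prefix touches only the leading prefix and refuses to silently overwrite a key if stripping produces a duplicate. add_prefix covers the opposite direction: loading an unprefixed checkpoint into a model that is itself wrapped in DataParallel (the error at the top is raised by such a wrapped model).

```python
from typing import Dict, Mapping, TypeVar

V = TypeVar("V")


def strip_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with `prefix` removed from the start of each key.

    Keys without the prefix are kept unchanged; insertion order is preserved.
    """
    out: Dict[str, V] = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        if new_key in out:
            # Two different keys collapsed into one name: fail loudly instead of overwriting
            raise KeyError(f"stripping {prefix!r} produced duplicate key {new_key!r}")
        out[new_key] = value
    return out


def add_prefix(state_dict: Mapping[str, V], prefix: str = "module.") -> Dict[str, V]:
    """Return a copy of state_dict with `prefix` added to every key that lacks it."""
    return {
        (key if key.startswith(prefix) else prefix + key): value
        for key, value in state_dict.items()
    }


if __name__ == "__main__":
    # Plain values stand in for tensors; only the keys matter here
    demo = {"module.model.conv.weight": 1, "module.model.conv.bias": 2}
    print(list(strip_prefix(demo)))  # ['model.conv.weight', 'model.conv.bias']
```

In the loading code this becomes self.model.load_state_dict(strip_prefix(torch.load(model_path, map_location='cpu'))). Recent PyTorch releases (1.9 and later) also provide torch.nn.modules.utils.consume_prefix_in_state_dict_if_present(state_dict, prefix), which strips the prefix in place; check the documentation for your installed version before relying on it.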