def load_param(self, model_path):
param_dict = torch.load(model_path)
# =============================================================================
# for i in param_dict:
# if 'fc' in i:
# continue
# self.state_dict()[i].copy_(param_dict[i])
# =============================================================================
# 源代码是多GPU训练,单gpu时出问题
for i in param_dict:
j = i.replace("base.","")
if 'fc' in i:
continue
if j in self.state_dict().keys():
self.state_dict()[j].copy_(param_dict[i])
注释内的在运行时报错,KeyError: ‘base.conv1.weight’,是由于多GPU训练保存的参数单GPU的环境无法直接用。改为注释下面的