在加载之前训练的模型时,发现加载模型的训练效果和随机初始化效果相同,经检查发现是load_state_dict函数中strict=False的原因,也就是说之前训练的参数并未加载到模型中训练,改为True后发现如下报错:
Missing key(s) in state_dict: "word_context1", "word_context2", "word_context3", "word_context4",....
Unexpected key(s) in state_dict: "module.word_context1", "module.word_context2", "module.word_context3", "module.word_context4", .....
发现多出module的参数,经查阅相关资料后改正:
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load(model_save_path).items()})
恢复了之前训练效果