问题现场:
背景:Python3.6,pytorch1.+,使用huggingface.co工具去finetuning bert模型,在多gpu上训练保存模型,在cpu上加载去做预测。
加载模型,对同一条数据测试结果不一样。
当时检查了:1
1.模型的状态是否是model.eval()?————不是这个问题
2.是否使用torch.no_grad()
,这句代码的意思是使得本次过的样本不会更新梯度。————不是这个问题
model.eval()
with torch.no_grad():
logits = model(input_ids, token_type_ids=None, attention_mask=(input_ids > 0))
3.是否有随机成分在里面,比如dropout,导致每次的网络结构都不一样。————不是这个问题
huggingface.co有句解释:
Models are now set in evaluation mode by default when instantiated with the from_pretrained() method.
To train them don’t forget to set them back in training mode (model.train()) to activate the dropout modules.
意思是eval()状态下没有dropout操作,在model.train()状态下,才激活dropout操作。
4.是不是对输入进行encode时,编码的tensor不一致?————不是这个问题
5.本项目中,有两点是“自己拼凑的代码”,一个是模型保存,一个是模型加载。————是“拼凑代码”的问题
模型定义使用的huggingface,多gpu并行训练模型,保存的时候,不能使用huggingface工具中的保存属性,网上查了一些方案,使用torch.save()保存模型。
if isinstance(model, torch.nn.DataParallel):
torch.save(model.state_dict(), config.save_path)
预测加载模型时,结合huggingface和torch的文档,使用如下加载模型:
model_state_dict = torch.load(model_path, map_location=torch.device('cpu'))
self.model = BertForSequenceClassification.from_pretrained(self.bert_name, state_dict=model_state_dict)
从这两点入手,挨个查,挨个验证,总有一个是bug。
解决方案:
准备:仔细去看torch的save和load文档。会发现对应办法:
本次定位到的问题是:模型保存时,保存的不对。
将保存方法改为下面即可:
torch.save(model.module.state_dict(), PATH)
分析下原因:
官网的说法:
torch.nn.DataParallel is a model wrapper that enables parallel GPU utilization.
To save a DataParallel model generically, save the model.module.state_dict().
This way, you have the flexibility to load the model any way you want to any device you want.
torch.nn.DataParallel是一个并行使用GPU训练模型的封装器。要保存一个通用的datapar平行模型,使用model.module.state_dict()保存。通过这种方式,您可以灵活地以任何方式加载模型到任何您想要的设备。
另外:
当使用单gpu训练模型,按照torch.save(model.state_dict(), config.save_path)保存模型,去cpu上加载预测模型是可以的,打分前后一致。
总结:
遇见该问题的主要原因是:使用数据并行训练模型后,模型保存错误,model.module的保存原理,待查
参考文档:
1.huggingface:https://huggingface.co/transformers/main_classes/model.html#transformers.PreTrainedModel.from_pretrained
2.torch的save和load官网:https://pytorch.org/tutorials/beginner/saving_loading_models.html
3.module类:https://blog.csdn.net/qq_27825451/article/details/95888267