1.问题描述
将训练好的模型使用 torch.save(model, path)命令进行保存后,通过Trained_model = torch.load(path)命令进行加载,常出现错误。具体代码和出现错误详情如下:
import torch
import torch.nn as nn
#引进训练好的模型进行测试
path = '/root/Save_model/bert_hide_model.pkl'
#模型加载
Trained_model = torch.load(path)
出现错误:
AttributeError: Can't get attribute 'BertClassificationModel' on <module '__main__'>
2.问题原因
保存下来的模型和参数不能在没有类定义时直接使用。使用pytorch
导入模型的时候有一个pickle
的操作,但是因为未知自定义的模型的结构,所以无法解析模型。
3.解决办法
将定义的模型类,加载到预测函数文件中即可。
import torch
import torch.nn as nn
#将定义好的模型类结构粘贴到当前文件中即可。
class BertClassificationModel(nn.Module):
def __init__(self):
super(BertClassificationModel, self).__init__()
#加载预训练模型
pretrained_weights="/root/Bert/chinese_roberta_wwm_large/"
self.roberta = transformers.BertModel.from_pretrained(pretrained_weights)
for param in self.roberta.parameters():
param.requires_grad = True
#定义线性函数
self.dense = nn.Linear(1024, 14) #wwm_large默认的隐藏单元数是1024, 输出单元是14,表示二分类
def forward(self, input_ids,token_type_ids,attention_mask):
#得到bert_output
bert_output = self.roberta(input_ids=input_ids,token_type_ids=token_type_ids, attention_mask=attention_mask)
#获得预训练模型的输出
bert_cls_hidden_state = bert_output[1]
#将768维的向量输入到线性层映射为向量
linear_output = self.dense(bert_cls_hidden_state)
return linear_output
#引进训练好的模型进行测试
path = '/root/Save_model/bert_hide_model.pkl'
#模型加载
Trained_model = torch.load(path)