【工程实践】“AttributeError: Can‘t get attribute ‘BertClassificationModel‘ on <module ‘__main__‘>“加载模型报错

1.问题描述

将训练好的模型使用 torch.save(model, path)命令进行保存后,通过Trained_model = torch.load(path)命令进行加载,常出现错误。具体代码和出现错误详情如下:

import torch
import torch.nn as nn


#引进训练好的模型进行测试
path = '/root/Save_model/bert_hide_model.pkl'

#模型加载
Trained_model = torch.load(path)

出现错误:

AttributeError: Can't get attribute 'BertClassificationModel' on <module '__main__'>

2.问题原因

保存下来的模型和参数不能在没有类定义时直接使用。使用pytorch导入模型的时候有一个pickle的操作,但是因为未知自定义的模型的结构,所以无法解析模型。

3.解决办法

        将定义的模型类,加载到预测函数文件中即可。


import torch
import torch.nn as nn

#将定义好的模型类结构粘贴到当前文件中即可。
class BertClassificationModel(nn.Module):
    def __init__(self):
        super(BertClassificationModel, self).__init__()   
        #加载预训练模型
        pretrained_weights="/root/Bert/chinese_roberta_wwm_large/"
        self.roberta = transformers.BertModel.from_pretrained(pretrained_weights)
        for param in self.roberta.parameters():
            param.requires_grad = True
        #定义线性函数      
        self.dense = nn.Linear(1024, 14)  #wwm_large默认的隐藏单元数是1024, 输出单元是14,表示二分类
        
    def forward(self, input_ids,token_type_ids,attention_mask):
        #得到bert_output
        bert_output = self.roberta(input_ids=input_ids,token_type_ids=token_type_ids, attention_mask=attention_mask)
        #获得预训练模型的输出
        bert_cls_hidden_state = bert_output[1]
        #将768维的向量输入到线性层映射为向量
        linear_output = self.dense(bert_cls_hidden_state)
        return  linear_output

#引进训练好的模型进行测试
path = '/root/Save_model/bert_hide_model.pkl'

#模型加载
Trained_model = torch.load(path)
  • 5
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 1
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值