KBQA-Bert学习记录-构建BERT-CRF模型

目录

1.__init__方法

2.forward方法


将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。

1.__init__方法

这里面主要是bert的参数的定义及导入,还有bert模型的导入。

MODEL_NAME = 'bert-base-chinese-model.bin'
CONFIG_NAME = 'bert-base-chinese-config.json'
VOB_NAME = 'bert-base-chinese-vocab.txt'


class BertCrf(nn.Module):
    def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:
        self.batch_first = batch_first
        # 模型配置文件、模型预训练参数文件判断
        if not os.path.exists(config_name):
            raise ValueError(
                "未找到模型配置文件 '{}'".format(config_name)
            )
        else:
            self.config_name = config_name
        if model_name is not None:
            if not os.path.exists(model_name):
                raise ValueError(
                    "未找到模型预训练参数文件 '{}'".format(model_name)
                )
            else:
                self.model_name = model_name
        else:
            self.model_name = None
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()

        # 配置bert的config文件
        self.bert_config = BertConfig.from_pretrained(self.config_name)
        self.bert_config.num_labels = num_tags
        self.model_kwargs = {'config': self.bert_config}

        # 如果模型不存在
        if self.model_name is not None:
            self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)
        else:
            self.bertModel = BertForTokenClassification(self.bert_config)
        self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)

2.forward方法

输出的结果,经过处理后,输入CRF函数,返回loss即可。

    def forward(self, input_ids: torch.Tensor,
                tags: torch.Tensor = None,
                attention_mask: Optional[torch.ByteTensor] = None,
                token_type_ids=torch.Tensor,
                decode:bool = True,
                reduction: str = 'mean')->List:
        emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]

        # 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]
        new_emissions = emissions[:, 1:-1]
        new_mask = attention_mask[:, 2:].bool()

        # tags为None, 是预测过程,不能求loss
        if tags is None:
            loss = None
            pass
        else:
            new_tags = tags[:, 1:-1]
            loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)

        if decode:
            tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)
            return [loss, tag_list]
        return [loss]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
bert-bilstm-crf模型源码是一种用于命名实体识别的深度学习模型。该模型结合了BERT预训练模型、双向LSTM和CRF(条件随机场)这三种模型结构。首先,模型使用预训练的BERT模型来提取输入句子的语义表示,然后将这些表示传入双向LSTM网络中,以捕捉句子中的序列信息。最后,通过CRF层来进行标记序列的最优化解码,得到最终的命名实体识别结果。 该模型的源码通常由多个部分组成,其中包括构建BERT模型的源码、构建双向LSTM网络的源码、构建CRF层的源码以及整合这三部分模型结构的源码。通过阅读模型源码,可以了解到模型的具体实现细节,包括参数初始化、前向传播和反向传播算法等。同时,也可以根据实际需求对源码进行修改和调整,以适配不同的数据集或任务。 bert-bilstm-crf模型源码通常是使用Python语言编写的,使用深度学习框架如PyTorch或TensorFlow来实现模型构建和训练。其中,BERT模型通常是通过Hugging Face的transformers库加载和使用的。另外,由于使用了深度学习框架,模型的源码还会包括数据预处理、训练和评估的代码部分。 总之,bert-bilstm-crf模型源码是一个宝贵的资源,通过阅读和理解源码,可以深入了解该模型的原理和实现细节,并且可以在实际应用中进行二次开发和优化,从而更好地适应具体的任务和数据。

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值