KBQA-Bert学习记录-构建BERT-CRF模型

目录

1.__init__方法

2.forward方法


将bert和crf模型结合起来,简单来说就是,设置好Bert模型,以及参数,得到的输出结果给CRF模型即可。

1.__init__方法

这里面主要是bert的参数的定义及导入,还有bert模型的导入。

MODEL_NAME = 'bert-base-chinese-model.bin'
CONFIG_NAME = 'bert-base-chinese-config.json'
VOB_NAME = 'bert-base-chinese-vocab.txt'


class BertCrf(nn.Module):
    def __init__(self, config_name: str, model_name:str = None, num_tags: int = 2, batch_first: bool = True) -> None:
        self.batch_first = batch_first
        # 模型配置文件、模型预训练参数文件判断
        if not os.path.exists(config_name):
            raise ValueError(
                "未找到模型配置文件 '{}'".format(config_name)
            )
        else:
            self.config_name = config_name
        if model_name is not None:
            if not os.path.exists(model_name):
                raise ValueError(
                    "未找到模型预训练参数文件 '{}'".format(model_name)
                )
            else:
                self.model_name = model_name
        else:
            self.model_name = None
        if num_tags <= 0:
            raise ValueError(f'invalid number of tags: {num_tags}')
        super().__init__()

        # 配置bert的config文件
        self.bert_config = BertConfig.from_pretrained(self.config_name)
        self.bert_config.num_labels = num_tags
        self.model_kwargs = {'config': self.bert_config}

        # 如果模型不存在
        if self.model_name is not None:
            self.bertModel = BertForTokenClassification.from_pretrained(self.model_name, **self.model_kwargs)
        else:
            self.bertModel = BertForTokenClassification(self.bert_config)
        self.crf_model = CRF(num_tags=num_tags, batch_first=batch_first)

2.forward方法

输出的结果,经过处理后,输入CRF函数,返回loss即可。

    def forward(self, input_ids: torch.Tensor,
                tags: torch.Tensor = None,
                attention_mask: Optional[torch.ByteTensor] = None,
                token_type_ids=torch.Tensor,
                decode:bool = True,
                reduction: str = 'mean')->List:
        emissions = self.bertModel(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)[0]

        # 去掉开头的[CLS]以及结尾,结尾可能有两种情况:1、<pad> 2、[SEP]
        new_emissions = emissions[:, 1:-1]
        new_mask = attention_mask[:, 2:].bool()

        # tags为None, 是预测过程,不能求loss
        if tags is None:
            loss = None
            pass
        else:
            new_tags = tags[:, 1:-1]
            loss = self.crf_model(emissions=new_emissions, tags=new_tags, mask=new_mask, reduction=reduction)

        if decode:
            tag_list = self.crf_model.decode(emissions=new_emissions, mask=new_mask)
            return [loss, tag_list]
        return [loss]
  • 1
    点赞
  • 1
    收藏
    觉得还不错? 一键收藏
  • 0
    评论
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值