常见ner解决方案简单汇总

bert+crf成为了目前ner领域的常见baseline,那常见的ner算法还有哪些呢?本文总结了常见的几种ner算法,并在数据集上进行指标对比。

方案介绍

1. bert+crf

 bert4torch/task_sequence_labeling_ner_crf.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_crf.py

最常见的bert输出[btz, seq_len, hdsz]的隐含层向量,过[hdsz, num_labels]的全连接得到[btz, seq_len, num_labels],即发射分数,然后过crf。

2. bert+cascade+crf(级联ner)

bert4torch/task_sequence_labeling_ner_cascade_crf.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_cascade_crf.py

级联ner,第一阶段(crf阶段)只识别BMES的实体边界,第二阶段根据实体边界pooling得到实体的向量,这里可以有很多方法,如实体首尾average,首的embedding,尾的embedding,全部token的average等;在后面接[hdsz, num_labels]的全连接做分类。

优点:降低crf类别较多时候学习的难度(crf阶段仅需要识别BMES),二阶段输出的实体概率可根据具体场景调整阈值,来实现precision和recall的控制

3. global_pointer(基于内积的token-pair识别模块)

bert4torch/task_sequence_labeling_ner_global_pointer.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_global_pointer.py

GlobalPointer:用统一的方式处理嵌套和非嵌套NER - 科学空间|Scientific Spaces

主要步骤

  • bert输出[btz, seq_len, hdsz],经过[hdsz, heads*head_size*2]的dense层, heads就是实体种类数
  • reshape得到[btz, seq_len, heads, head_size]的qw和kw
  • qw和kw通过RoPE相对位置编码
  • qw和kw内积,得到[btz, heads, seq_len, seq_len]
  • 排除padding,排除下三角(不含对角线),scale
  • [btz, heads,seq_len, seq_len]的label计算CrossEntropyLoss

4. efficient_global_pointer

bert4torch/task_sequence_labeling_ner_efficient_global_pointer.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_efficient_global_pointer.py

Efficient GlobalPointer:少点参数,多点效果

主要步骤

  • bert输出[btz, seq_len, hdsz],经过[hdsz, head_size*2]的dense层得到seq_output
  • reshape得到[btz, seq_len, head_size]的qw和kw
  • qw和kw通过RoPE相对位置编码
  • qw和kw内积并scale,得到[btz, seq_len, seq_len]的logits(是否是实体的打分)
  • [btz, seq_len, head_size*2]的seq_output[head_size*2, heads*2]的dense层,得到[btz, seq_len, heads*2]
  • reshape成[btz, heads, seq_len, 2]的bias(实体类别的打分)
  • logits和bias相加([btz, 1, seq_len, seq_len] + [btz, heads, seq_len, 1] + [btz, heads, 1, seq_len]
  • 排除padding,排除下三角(不含对角线)
  • [btz, heads,seq_len, seq_len]的label计算CrossEntropyLoss

5. ner_mrc(阅读理解)

bert4torch/task_sequence_labeling_ner_mrc.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_mrc.py

比如实体类型有3类,人名,地址名,机构名,则每个sentence前加上短语“找出下述句子中的地址名“来识别地址名称,以此类推。

主要步骤

  • 按如上方式整理数据集,一条样本拆分成多条(实体类别数)
  • 有一些全连接+激活的中间层
  • 最后两个[hdsz, 2]的全连接用来标注实体起始和结束的logit,[btz, seq_len, 2]
  • 和真实的start_label: [btz, seq_len]和end_label:[btz, seq_len]计算CELoss

6. ner_span(指针网络)

bert4torch/task_sequence_labeling_ner_span.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_span.py

两个全连接用来标记实体起始和结束

主要步骤

  • bert输出[btz, seq_len, hdsz]
  • 有一些全连接+激活的中间层
  • 最后两个[hdsz, ent_type_num+1]的全连接用来标注实体的种类,和起始和结束
  • 和真实的start_label: [btz, seq_len]和end_label:[btz, seq_len]计算CELoss

7. tplinker_plus

https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_tplinker_plus.pyicon-default.png?t=M7J4https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_tplinker_plus.py

也是一种token_pair的一种,从关系提取中修改过来

主要步骤

  • bert输出[btz, seq_len, hdsz]
  • seq_len中第i个token的embedding需要和>i的token建联,可以是concat+dense或者其他更复杂的,总之就是会得到[btz, seq_len*(seq_len+1)/2,hdsz]的输出,这里pair_len=seq_len*(seq_len+1)/2
  • [hdsz, num_labels]的全连接
  • [btz, pair_len, num_labels]的label计算loss

评测

  • 人民日报数据集+bert预训练模型
  • valid集指标

bert4torch

本测试全部基于bert4torch开发,长期维护,欢迎star

https://github.com/Tongjilibo/bert4torch​github.com/Tongjilibo/bert4torchhttps://link.zhihu.com/?target=https%3A//github.com/Tongjilibo/bert4torch

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值