bert+crf成为了目前ner领域的常见baseline,那常见的ner算法还有哪些呢?本文总结了常见的几种ner算法,并在数据集上进行指标对比。
方案介绍
1. bert+crf
bert4torch/task_sequence_labeling_ner_crf.py at master · Tongjilibo/bert4torch · GitHub参考bert4keras的pytorch实现. Contribute to Tongjilibo/bert4torch development by creating an account on GitHub.https://github.com/Tongjilibo/bert4torch/blob/master/examples/sequence_labeling/task_sequence_labeling_ner_crf.py
最常见的bert输出[btz, seq_len, hdsz]
的隐含层向量,过[hdsz, num_labels]
的全连接得到[btz, seq_len, num_labels]
,即发射分数,然后过crf。
2. bert+cascade+crf(级联ner)
级联ner,第一阶段(crf阶段)只识别BMES的实体边界,第二阶段根据实体边界pooling得到实体的向量,这里可以有很多方法,如实体首尾average,首的embedding,尾的embedding,全部token的average等;在后面接[hdsz, num_labels]的全连接做分类。
优点:降低crf类别较多时候学习的难度(crf阶段仅需要识别BMES),二阶段输出的实体概率可根据具体场景调整阈值,来实现precision和recall的控制
3. global_pointer(基于内积的token-pair识别模块)
GlobalPointer:用统一的方式处理嵌套和非嵌套NER - 科学空间|Scientific Spaces
主要步骤
- bert输出
[btz, seq_len, hdsz]
,经过[hdsz, heads*head_size*2]
的dense层, heads就是实体种类数 - reshape得到
[btz, seq_len, heads, head_size]
的qw和kw - qw和kw通过RoPE相对位置编码
- qw和kw内积,得到
[btz, heads, seq_len, seq_len]
- 排除padding,排除下三角(不含对角线),scale
- 和
[btz, heads,seq_len, seq_len]
的label计算CrossEntropyLoss
4. efficient_global_pointer
Efficient GlobalPointer:少点参数,多点效果
主要步骤
- bert输出
[btz, seq_len, hdsz]
,经过[hdsz, head_size*2]
的dense层得到seq_output - reshape得到
[btz, seq_len, head_size]
的qw和kw - qw和kw通过RoPE相对位置编码
- qw和kw内积并scale,得到
[btz, seq_len, seq_len]
的logits(是否是实体的打分) - 对
[btz, seq_len, head_size*2]
的seq_output过[head_size*2, heads*2]
的dense层,得到[btz, seq_len, heads*2]
- reshape成
[btz, heads, seq_len, 2]
的bias(实体类别的打分) - logits和bias相加(
[btz, 1, seq_len, seq_len]
+[btz, heads, seq_len, 1]
+[btz, heads, 1, seq_len]
) - 排除padding,排除下三角(不含对角线)
- 和
[btz, heads,seq_len, seq_len]
的label计算CrossEntropyLoss
5. ner_mrc(阅读理解)
比如实体类型有3类,人名,地址名,机构名,则每个sentence前加上短语“找出下述句子中的地址名“来识别地址名称,以此类推。
主要步骤
- 按如上方式整理数据集,一条样本拆分成多条(实体类别数)
- 有一些全连接+激活的中间层
- 最后两个
[hdsz, 2]
的全连接用来标注实体起始和结束的logit,[btz, seq_len, 2]
- 和真实的start_label:
[btz, seq_len]
和end_label:[btz, seq_len]
计算CELoss
6. ner_span(指针网络)
两个全连接用来标记实体起始和结束
主要步骤
- bert输出
[btz, seq_len, hdsz]
- 有一些全连接+激活的中间层
- 最后两个
[hdsz, ent_type_num+1]
的全连接用来标注实体的种类,和起始和结束 - 和真实的start_label:
[btz, seq_len]
和end_label:[btz, seq_len]
计算CELoss
7. tplinker_plus
也是一种token_pair的一种,从关系提取中修改过来
主要步骤
- bert输出
[btz, seq_len, hdsz]
- seq_len中第i个token的embedding需要和>i的token建联,可以是concat+dense或者其他更复杂的,总之就是会得到
[btz, seq_len*(seq_len+1)/2,hdsz]
的输出,这里pair_len=seq_len*(seq_len+1)/2
- 过
[hdsz, num_labels]
的全连接 - 和
[btz, pair_len, num_labels]
的label计算loss
评测
- 人民日报数据集+bert预训练模型
- valid集指标
bert4torch
本测试全部基于bert4torch开发,长期维护,欢迎star