一、背景
之前自己写了个简单的开源训练框架Bo仔很忙:bert4torch(参考bert4keras的pytorch实现),张罗着给框架不断增加示例,看到了W2NER,于是参考源代码迁移到bert4torch上,并在中文数据集上做了测试。关于W2NER的解读有下面几篇。
NER任务最新SOTA模型W2NER_colourmind的博客-CSDN博客_ner任务
2022 统一NER SOTA模型【W2NER】详解 - 知乎
二、简要思路介绍
1、label的表示
W2NER是能够统一处理扁平实体、重叠实体和非连续实体三种NER任务,这得益于其label的统一表示,如下图所示,其实体内相邻的token用NNW来表示,实体的边界用THW-S来表示。
2、模型结构
模型主要网络结构有,bert层、双向LSTM层、卷积层、CLN层,以及输出层Co-Predictor(由仿射变换+MLP组成),其主要流程如下
- input_ids输入bert层和双向LSTM层,得到
[btz, seqlen, hdsz]
的表示 - 过CLN(条件LayerNorm层),得到
[btz, seqlen, seqlen, hdsz]
的word_embedding - concat上另外两个embedding,距离embedding和区域embedding
- 依次经过卷积层和输出层,得到
[btz, seqlen, seqlen, entnum]
的表示,可以和labels计算交叉熵损失
三、实验比对
在数据集上测试看看,在人民日报数据集上token粒度f1=97.37, ent粒度f1=96.32,具体测试结果如下表(含其他算法的测试结果)
- 人民日报数据集+bert预训练模型
- valid集指标
solution | epoch | f1_token | f1_entity | comment |
---|---|---|---|---|
bert+crf | 18/20 | 96.89 | 96.05 | —— |
bert+crf+init | 18/20 | 96.93 | 96.08 | 用训练数据初始化crf权重 |
bert+crf+freeze | 11/20 | 96.89 | 96.13 | 用训练数据生成crf权重(不训练) |
bert+cascade+crf | 5/20 | 98.10 | 96.26 | crf类别少所以f1_token偏高 |
bert+crf+posseg | 13/20 | 97.32 | 96.55 | 加了词性输入 |
bert+global_pointer | 18/20 | —— | 95.66 | —— |
bert+efficient_global_pointer | 17/20 | —— | 96.55 | —— |
bert+mrc | 7/20 | —— | 95.75 | —— |
bert+span | 13/20 | —— | 96.31 | —— |
bert+tplinker_plus | 20/20 | —— | 95.71 | 长度限制明显 |
uie | 20/20 | —— | 96.57 | zeroshot:f1=60.8, fewshot-100样本:f1=85.82, 200样本:f1=86.40 |
W2NER | 18/20 | 97.37 | 96.32 | 对显存要求较高 |
四、代码
全部代码测试都是基于bert4torch框架,这是一个基于pytorch的训练框架,前期以效仿和实现bert4keras的主要功能为主,特点是尽量简洁轻量,提供丰富示例,有兴趣的小伙伴可以试用,欢迎star。
bert4torch/task_sequence_labeling_ner_W2NER.py at master · Tongjilibo/bert4torch · GitHub