Span-based Named Entity Recognition for a Traditional Chinese Medicine (TCM) Knowledge Graph



First place in NER in the preliminary round of the 1st National TCM Knowledge Graph Construction and Application Competition: a span-based NER model

Preliminary round: named entity recognition

Final preliminary-round result: weighted F1 of 69.21%, ranked 1st

Although this was only a small competition, and there are plenty of tutorials online for knowledge-graph-related models, it was still very valuable for a rookie entering an algorithm competition for the first time. The competition ran around June-July 2022; for personal reasons I did not take part in the final round, so the follow-up relation-extraction part was never implemented. The project was only half finished last year and parts of it were only polished this year, so there may well be mistakes or omissions; corrections and advice are welcome.

1 Project overview

1.1 Background

Knowledge graphs are an emerging technology in knowledge management and knowledge services, and they offer an ideal way to link, integrate, and analyze TCM clinical knowledge. Building on TCM medical-case records and other clinical knowledge sources, the organizers established a preliminary TCM clinical knowledge graph around core concepts such as disease, syndrome, symptom, formula, and Chinese herbs, to promote the interoperability of TCM clinical knowledge, reveal the relationships between formulas and syndromes, and support TCM clinical research and decision-making.

TCM is an ancient system of medicine. Over thousands of years of practice, physicians of successive generations accumulated rich clinical experience, formed a complete body of knowledge, and produced a vast clinical literature. In recent years the state has strongly supported TCM, and both clinical practice and clinical research in the field have made great progress; TCM clinical methods are widely recognized internationally and have spread to 183 countries and regions.

Using information technology to manage and serve TCM clinical knowledge is a pioneering effort with great practical value. The knowledge graph (Knowledge Graph) has recently become an emerging technology in knowledge management, widely adopted because it is easy to learn, highly extensible, and well suited to intelligent applications. It helps link and integrate clinical guidelines, TCM medical cases, formula knowledge, and other knowledge types, helps organize TCM clinical experience and academic thought, and enables intelligent, personalized TCM knowledge services, so it has broad application prospects in TCM clinical practice.

1.2 Task description

Using the annotated corpus and related passages provided by the organizers as production data, and publicly available corpora (such as the People's Daily corpus) as reference data, perform named entity recognition with rule-based, dictionary-based, machine-learning, and deep-learning methods.

The goal is to identify TCM-related entities in the titles and abstracts of TCM journal articles. The entity types include: 中医诊断 (TCM diagnosis), 西医诊断 (Western-medicine diagnosis), 中医证候 (TCM syndrome), 临床表现 (clinical manifestation), 中医治则 (TCM treatment principle), 方剂 (formula), 中药 (Chinese herb), 其他治疗 (other treatment), etc.

The provided training data is in BIO format, for example:

  • 现 O
  • 头 O
  • 昏 O
  • 口 B-临床表现
  • 苦 I-临床表现

1.3 Environment

!pip install --upgrade paddlenlp==2.2.6
import paddle
from paddle.io import Dataset
import paddle.nn as nn
import paddle.nn.functional as F

import paddlenlp
from paddlenlp.transformers import LinearDecayWithWarmup
from paddlenlp.data import Stack, Pad, Tuple
from paddlenlp.metrics import ChunkEvaluator
from paddlenlp.layers.crf import LinearChainCrf, LinearChainCrfLoss, ViterbiDecoder
from functools import partial

import numpy as np
import pandas as pd
import time
import os

# visualization tools
import seaborn as sns
import matplotlib.pyplot as plt
from visualdl import LogWriter

2 Solution overview for the preliminary-round NER task (Task 1)

2.1 Data

Dataset work: cleaning, including fixing some mislabeled spans and using an entity dictionary to fill in missing annotations; the medical-case data was used to augment the training set (this seemed to make little difference).

2.2 Model architecture

Three layers. Bottom: a sentence-representation layer using ernie-health-chinese, Baidu's open-source medical pretrained language model. Middle: a Bi-LSTM over the encoder output, so the model can learn forward and backward dependencies. Top: span prediction — the LSTM sequence output goes through fully-connected head 1 to predict entity starts, with output shape [batch_size, seq_len, num_labels]; the argmax of the start predictions is then concatenated to the LSTM sequence output (giving shape [batch_size, seq_len, hidden_size + 1], where hidden_size is the encoder hidden size) and fed to fully-connected head 2 to predict entity ends. A shape walk-through follows.
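In tensor shapes (a sketch matching the model code in section 4.3; H denotes the encoder hidden size):

# encoder          [batch_size, seq_len]         ->  [batch_size, seq_len, H]
# Bi-LSTM          [batch_size, seq_len, H]      ->  [batch_size, seq_len, H]    (H/2 per direction)
# start head       [batch_size, seq_len, H]      ->  [batch_size, seq_len, num_labels]
# argmax + concat  [batch_size, seq_len, H] plus [batch_size, seq_len, 1]  ->  [batch_size, seq_len, H + 1]
# end head         [batch_size, seq_len, H + 1]  ->  [batch_size, seq_len, num_labels]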

2.3 Training and optimization

Guided by the exploratory data analysis, the loss is a multi-class focal loss (label smoothing would probably also work, but was not implemented). It reduces the model's tendency to favor particular entity labels, mitigates the impact of class imbalance, and improves generalization. FGM/PGD adversarial training makes training more stable and reduces overfitting. The optimizer is AdamW.

Focal loss was originally proposed in work by Kaiming He and colleagues to address model-performance problems caused by class imbalance in computer vision.

For an introduction to focal loss in both binary and multi-class settings, with implementations, see https://blog.csdn.net/u014311125/article/details/109470137 [7].
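The FocalLoss used later is imported from model/loss/focalloss.py, which is not shown in this notebook. Below is a minimal multi-class sketch of the idea (my own reconstruction, not the project's exact code), assuming logits of shape [batch_size, seq_len, num_labels] and labels padded with -1 as in this project:

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

class FocalLossSketch(nn.Layer):
    """Multi-class focal loss: scales cross-entropy by (1 - pt)^gamma so that
    easy, frequent tokens contribute less and rare labels matter more."""
    def __init__(self, num_classes, gamma=2.0, ignore_index=-1):
        super().__init__()
        self.gamma = gamma
        self.ignore_index = ignore_index   # padded label positions use -1 in this project

    def forward(self, logits, labels):
        # logits: [batch_size, seq_len, num_labels]; labels: [batch_size, seq_len]
        logits = logits.reshape([-1, logits.shape[-1]])
        labels = labels.reshape([-1])
        ce = F.cross_entropy(logits, labels, ignore_index=self.ignore_index,
                             reduction='none')       # -log(pt); 0 at ignored positions
        pt = paddle.exp(-ce)                          # probability of the gold class
        focal = ((1.0 - pt) ** self.gamma) * ce       # down-weight easy tokens (pt close to 1)
        mask = paddle.cast(labels != self.ignore_index, focal.dtype)
        return (focal * mask).sum() / paddle.clip(mask.sum(), min=1.0)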

3 Competition data processing and analysis

3.1 Loading the datasets

There are four datasets: train, dev, test, and a related medical-case corpus. The versions loaded here have been partially preprocessed (the characters 'B', 'I', and '-' are stripped from the tags, and the characters of each record are joined into a single string), which simplifies the initial analysis and the later conversion to span format.

A span-format dataset looks like this (after each character, the two columns mark entity starts and entity ends respectively; a sketch of the conversion follows the example):

  • 现 O O
  • 头 O O
  • 昏 O O
  • 口 临床表现 O
  • 苦 O 临床表现
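The conversion itself is done by utils.BIO2SPAN, imported in section 4.1.1 but not shown here. A sketch of the idea (an assumption based on its call signature BIO2SPAN(labels, label2ids, seq_len) and the example output in section 4.1.1; truncation handling omitted):

def bio2span_sketch(labels, label2ids, seq_len):
    # labels: space-separated per-character tags with the 'B-'/'I-' prefixes already
    # stripped, e.g. 'O O 临床表现 临床表现 O'; seq_len = len(input_ids) incl. [CLS]/[SEP]
    O = label2ids['O']
    tags = labels.split()
    label_start = [O] * seq_len
    label_end = [O] * seq_len
    i = 0
    while i < len(tags):
        if tags[i] == 'O':
            i += 1
            continue
        j = i                                    # extend over the run of identical tags
        while j + 1 < len(tags) and tags[j + 1] == tags[i]:
            j += 1
        label_start[i + 1] = label2ids[tags[i]]  # +1 offsets the leading [CLS] token
        label_end[j + 1] = label2ids[tags[i]]
        i = j + 1
    return label_start, label_end

Note that in this stripped format two adjacent entities with the same label collapse into one run; combine_item in section 4.4 applies the same convention to predictions.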
# training set in span format
train_df = pd.read_csv('format_data/train.csv',sep='\t')

# dev set in span format
dev_df = pd.read_csv('format_data/dev.csv',sep='\t')

# medical-case (yi'an) corpus
yian_df=pd.read_csv('format_data/yian.csv',sep='\t')

# test set in span format
test_df = pd.read_csv('format_data/test.csv',sep='\t')
yian_df['text'].str.len()[:4]
# lengths of the first few medical-case texts
0    1097
1    1257
2     975
3    3095
Name: text, dtype: int64

Loading the preprocessed label table

The NER label file is read in; the labels are:

O、中医治则 (TCM treatment principle)、方剂 (formula)、中医治疗 (TCM treatment)、中医证候 (TCM syndrome)、中医诊断 (TCM diagnosis)、中药 (Chinese herb)、其他治疗 (other treatment)、西医治疗 (Western-medicine treatment)、西医诊断 (Western-medicine diagnosis)、临床表现 (clinical manifestation)

labeldict = open('dict/label.txt',mode='r').readlines()
labeldict = [i.strip() for i in labeldict]

label2ids = {x:i for i,x in enumerate(labeldict)}
ids2label = {i:x for i,x in enumerate(labeldict)}

# ids2label / label2ids convert between label ids and label names

3.2 Exploratory data analysis

print('Average text length:',
    '\n  train:',sum(train_df['text'].str.len()/2)/len(train_df),
    '\n  dev:',sum(dev_df['text'].str.len()/2)/len(dev_df),
    '\n  test:',sum(test_df['text'].str.len()/2)/len(test_df))
Average text length: 
  train: 37.96937999239255 
  dev: 38.00762195121951 
  test: 36.4855403348554
# Length distribution per record; characters were joined with spaces, so divide each length by 2
sns.distplot(train_df['text'].str.len()/2)
sns.distplot(dev_df['text'].str.len()/2)
sns.distplot(test_df['text'].str.len()/2)

# 1. train, dev, and test lengths are identically distributed
# 2. lengths fall within [0, 150]
<matplotlib.axes._subplots.AxesSubplot at 0x7f40c97fe310>

[Figure: overlapping text-length distributions for the train, dev, and test sets]

from matplotlib import font_manager
 
plt.rcParams["font.sans-serif"]=["DejaVu"] # set the default font
plt.rcParams["axes.unicode_minus"]=False # render the minus sign correctly instead of mojibake

# Chinese font used for the tick labels in the plot helpers below
font_manager.FontProperties(fname='work/font/SIMHEI.TTF');
def label_count_plot(df):
    label_text = ''
    for i in df['labels'].str.split():
        label_text += ' '.join(i)+' '
    # count label occurrences
    plt.figure(figsize=(16,9),dpi=200)
    label_count = pd.Series(label_text.split()).value_counts()
    ax = sns.barplot(x=label_count.keys()[1:],y=label_count.values[1:])

    ax.set_xticklabels(
        labels=label_count.keys()[1:],
        fontdict={
            'fontproperties':font_manager.FontProperties(fname='work/font/SIMHEI.TTF')
        },
    )
    ax.set_xticklabels(ax.get_xticklabels(),rotation = 30);
    return label_count
    
# Count entity labels in the training set (no distinction between entity head, tail, etc.)
train_label_count = label_count_plot(train_df)

[Figure: label counts in the training set]

# likewise for the dev and test sets
# train_label_count = label_count_plot(dev_df)
# train_label_count = label_count_plot(test_df)

def entity_countplot(label_counter):
    entity_counter=dict()
    for i in label_counter.keys():
        if 'B-' in i:
            entity_counter.setdefault(i,label_counter[i])
    plt.figure(figsize=(16,5),dpi=200)
    ax = sns.barplot(y=list(entity_counter.values()),x=list(entity_counter.keys()))

    ax.set_xticklabels(
        labels=entity_counter.keys(),
        fontdict={'fontproperties':font_manager.FontProperties(fname='work/font/SIMHEI.TTF')},
    )
    ax.set_xticklabels(ax.get_xticklabels(),rotation = 30);
bio_train_df = pd.read_csv('format_data/bio_train.csv',sep='\t')
bio_dev_df = pd.read_csv('format_data/bio_dev.csv',sep='\t')
bio_test_df = pd.read_csv('format_data/bio_test.csv',sep='\t')
entity_countplot(label_count_plot(bio_train_df))

[Figures: BIO label counts and per-entity-type counts (B- tags) for the training set]

entity_countplot(label_count_plot(bio_dev_df))

[Figures: BIO label counts and per-entity-type counts for the dev set]

entity_countplot(label_count_plot(bio_test_df))

[Figures: BIO label counts and per-entity-type counts for the test set]

3.3 Takeaways

  1. The train, dev, and test sets are identically distributed; lengths fall within [0, 150], with an average of about 37 characters.

  2. Manual inspection also found many missing annotations and a few wrong ones; the training set is later patched with an entity-label dictionary during preprocessing.

  3. The label distribution is very uneven: 临床表现, 西医诊断, and 中药 are the most frequent entity types, while 中医诊断, 中医治则, and 其他治疗 are rare, so augmenting the data with records that contain those entities is worth considering.

4 Implementation

4.1 Data preprocessing

The encoding layer represents sentences with ernie-health-chinese, Baidu's open-source medical pretrained language model.

# load the pretrained model and its matching tokenizer
pretrained_model = paddlenlp.transformers.AutoModel.from_pretrained('ernie-health-chinese')
tokenizer = paddlenlp.transformers.AutoTokenizer.from_pretrained('ernie-health-chinese')
[2023-02-09 16:32:28,768] [    INFO] - We are using <class 'paddlenlp.transformers.electra.modeling.ElectraModel'> to load 'ernie-health-chinese'.
[2023-02-09 16:32:28,772] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie-health-chinese/ernie-health-chinese.pdparams and saved to /home/aistudio/.paddlenlp/models/ernie-health-chinese
[2023-02-09 16:32:28,775] [    INFO] - Downloading ernie-health-chinese.pdparams from https://paddlenlp.bj.bcebos.com/models/transformers/ernie-health-chinese/ernie-health-chinese.pdparams
100%|██████████| 392M/392M [00:19<00:00, 20.7MB/s] 
W0209 16:32:48.748164   184 gpu_resources.cc:61] Please NOTE: device: 0, GPU Compute Capability: 7.0, Driver API Version: 11.2, Runtime API Version: 11.2
W0209 16:32:48.752935   184 gpu_resources.cc:91] device: 0, cuDNN Version: 8.2.
[2023-02-09 16:32:51,206] [    INFO] - We are using <class 'paddlenlp.transformers.electra.tokenizer.ElectraTokenizer'> to load 'ernie-health-chinese'.
[2023-02-09 16:32:51,209] [    INFO] - Downloading https://paddlenlp.bj.bcebos.com/models/transformers/ernie-health-chinese/vocab.txt and saved to /home/aistudio/.paddlenlp/models/ernie-health-chinese
[2023-02-09 16:32:51,212] [    INFO] - Downloading vocab.txt from https://paddlenlp.bj.bcebos.com/models/transformers/ernie-health-chinese/vocab.txt
100%|██████████| 109k/109k [00:00<00:00, 2.26MB/s]

4.1.1 Tokenizing the text

# Load the entity names collected per label during preprocessing; used to patch the training set

def load_dictionary(dictionary_path):
    dictionary = {}
    for i in labeldict[1:]:
        with open(dictionary_path+i+'.txt','r') as f:
            data = (' '.join(f.readlines())).split()
            dictionary[i] = set(data)
    return dictionary
dictionary = load_dictionary('dict/dictionary/')
from utils.BIO2SPAN import BIO2SPAN
from utils.addDict2Span import addDict2Span

# Convert characters to id vectors; in span format only an entity's first and last characters carry its label
def convert_example(
        text,
        label2ids, 
        labels=None, 
        tokenizer=None, 
        max_seq_length=180, 
        is_test=False,
        dictionary=dictionary,
        is_SPAN=True,
        is_BIO=False
    ):          

    encoded_inputs = tokenizer(
        text=text,
        max_seq_len=max_seq_length,
    )
    input_ids = encoded_inputs["input_ids"]
    token_type_ids = encoded_inputs["token_type_ids"]

    if is_test:
        return input_ids, token_type_ids
    else:
        if is_SPAN:
            label_start, label_end = BIO2SPAN(labels, label2ids, len(input_ids))
            # assert len(input_ids) == len(token_type_ids) == len(label_start) == len(label_end)
            if dictionary is not None:            # supplement span labels from the dictionary without altering the gold annotation
                # fill in unannotated entities from the dictionary; addDict2Span runs on the fly
                # for every sample, so it may cost some throughput (could be optimized later)
                label_start, label_end = addDict2Span(text, dictionary,label_start, label_end, label2ids)
            return input_ids, token_type_ids, label_start, label_end
        elif is_BIO:
            label =  ['O'] + labels.split() + ['O']
            label = [label2ids[i] for i in label]
            return input_ids, token_type_ids, label
idx = 13
input_ids, token_type_ids, label_start, label_end  = convert_example(
    text=train_df['text'][idx],
    label2ids=label2ids,
    labels=train_df['labels'][idx],
    tokenizer=tokenizer
)

print('text:',' '.join(train_df['text'][idx].split()),'  input_ids:',input_ids)
print('label:',train_df['labels'][idx])
print('token_type_ids:',token_type_ids)
print('label_start:',label_start)
print('label_end:',label_end)
text: 1 9 9 9 年 4 月 1 2 日 起 恶 寒 发 热 , 头 痛 身 楚 , 腹 胀 纳 呆 , 溲 黄 便 软 , 谷 丙 转 氨 酶 & g t   input_ids: [2, 20, 746, 746, 746, 151, 405, 53, 20, 249, 362, 108, 649, 890, 61, 273, 1, 147, 95, 194, 794, 1, 276, 523, 1607, 1763, 1, 5102, 348, 202, 525, 1, 1360, 1298, 474, 869, 974, 1, 1287, 1525, 3]
label: O O O O O O O O O O O O O O O O O O O O O O O 临床表现 临床表现 O O O O O O O O O O O O O O
token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
label_start: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
label_end: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10, 0, 0, 0, 0, 0, 0, 0, 10, 0, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
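addDict2Span itself lives in utils/addDict2Span.py and is not shown. A hypothetical implementation consistent with the call above (a reconstruction, not the project's code); it also makes clear why the in-loop dictionary matching can cost throughput:

def add_dict_to_span_sketch(text, dictionary, label_start, label_end, label2ids):
    chars = text.split()                         # text stores space-separated characters
    O = label2ids['O']
    for label, names in dictionary.items():      # brute-force scan: O(|dict| * |text|)
        for name in names:
            n = len(name)
            for i in range(len(chars) - n + 1):
                if ''.join(chars[i:i + n]) != name:
                    continue
                s, e = i + 1, i + n              # +1 offsets the leading [CLS] token
                # only fill regions the gold annotation left completely unlabeled
                if all(v == O for v in label_start[s:e + 1]) and all(v == O for v in label_end[s:e + 1]):
                    label_start[s] = label2ids[label]
                    label_end[e] = label2ids[label]
    return label_start, label_end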
# For convenience, bind some default arguments to convert_example
# sample-conversion function for the training set 
trans_func = partial(
    convert_example,
    label2ids=label2ids,
    tokenizer=tokenizer,
    max_seq_length=None
)

4.2 Custom Dataset and DataLoader

class my_Dataset(Dataset):
    def __init__(self, data, trans_func, is_test=False):
        super(my_Dataset, self).__init__()
        self.data = data
        self.trans_func = trans_func
        self.is_test = is_test
        
    def __getitem__(self, index): 
        if self.is_test == False:
            input_ids, token_type_ids, label_start, label_end = self.trans_func(
                text = self.data['text'][index],
                labels = self.data['labels'][index],
                is_test = self.is_test
            )
            return input_ids, token_type_ids, label_start, label_end

        if self.is_test:
            input_ids, token_type_ids = self.trans_func(
                text = self.data['text'][index],
                is_test = self.is_test
            )
            return input_ids, token_type_ids

    def __len__(self):
        return len(self.data)
train_df = train_df.append(dev_df, ignore_index = True)   # merge the dev set into the training set
train_ds = my_Dataset(train_df, trans_func,is_test=False)
test_ds = my_Dataset(test_df, trans_func,is_test=False)    # the released test set also carries labels, so is_test=False
# Each training sample yields 4 fields: input_ids, token_type_ids, label_start, label_end,
# so we define one batching (padding) operation per field

# batchify_fn for the training and dev sets
ignore_label = -1 # padded label positions are set to -1
batchify_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),        # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),   # token_type_ids
    Pad(axis=0, pad_val=ignore_label),                  # label_start
    Pad(axis=0, pad_val=ignore_label)                   # label_end
): fn(samples)

# batchify_fn for the test set
batchify_test_fn = lambda samples, fn=Tuple(
    Pad(axis=0, pad_val=tokenizer.pad_token_id),  # input_ids
    Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # token_type_ids
): fn(samples)
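To make the collate functions concrete, here is a toy call (made-up values; the padding comments assume tokenizer.pad_token_id == 0, which holds for the ERNIE vocab):

samples = [
    ([2, 20, 3], [0, 0, 0], [0, 10, 0], [0, 0, 10]),   # 3-token sample
    ([2, 3],     [0, 0],    [0, 0],     [0, 0]),       # 2-token sample
]
input_ids, token_type_ids, label_start, label_end = batchify_fn(samples)
# input_ids   -> [[2, 20, 3], [2, 3, 0]]    padded with pad_token_id
# label_start -> [[0, 10, 0], [0, 0, -1]]   labels padded with ignore_label = -1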
batch_sampler = paddle.io.DistributedBatchSampler(train_ds, batch_size=128, shuffle=True)
train_dataloader = paddle.io.DataLoader(
        dataset=train_ds,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn
)

# define test_data_loader
batch_sampler = paddle.io.BatchSampler(test_ds, batch_size=1, shuffle=False)
test_dataloader = paddle.io.DataLoader(
        dataset=test_ds,
        batch_sampler=batch_sampler,
        collate_fn=batchify_fn,
        # collate_fn=batchify_test_fn,
)

4.3 Building the NER model

from model.loss.focalloss import FocalLoss
# from model.biaffine import Biaffine1, Biaffine2 # biaffine modules
from model.net import BertLSTMSpanNER, BertSpanNER, BertLstmCrf
from model.ErnieSeqLabelModel import ErnieSeqLabel

from model.AdversarialTrain_block import FGM, PGD
class BertLSTMSpanNER(nn.Layer):
    def __init__(
        self,
        pretrained_model,
        num_labels,
        ignore_label=-1,
        loss_type='ce'):

        super().__init__()
        self.loss_type = loss_type 
        self.num_labels = num_labels
        self.encoder = pretrained_model
        self.dropout = nn.Dropout(0.1)
        self.lstm = nn.LSTM(
            input_size=self.encoder.config['hidden_size'],
            hidden_size=int(self.encoder.config['hidden_size']/2),
            dropout=0.1,
            direction='bidirect'
        )

        self.start_fc = nn.Linear(
            in_features=self.encoder.config['hidden_size'],
            out_features=num_labels,
        )
        self.end_fc = nn.Sequential(
            nn.Linear(self.encoder.config['hidden_size']+1,self.encoder.config['hidden_size']),
            nn.Tanh(),
            nn.LayerNorm(self.encoder.config['hidden_size']),
            nn.Linear(self.encoder.config['hidden_size'],num_labels)
        )
        # choose the loss function
        assert self.loss_type in ['lsc', 'fcl', 'ce']
        if self.loss_type =='lsc':
            # note: label smoothing was never implemented in this project
            self.loss_fct = LabelSmoothingCrossEntropy()
        elif self.loss_type == 'fcl':
            self.loss_fct = FocalLoss(num_classes=self.num_labels)
        else:
            self.loss_fct = nn.loss.CrossEntropyLoss(ignore_index=ignore_label)

    def forward(
        self,
        input_ids,
        token_type_ids,
        label_start=None,
        label_end=None):

        sequence_out = self.encoder(
            input_ids,
            token_type_ids=token_type_ids
        )
        sequence_out,_ = self.lstm(self.dropout(sequence_out))    # [batch_size, seq_len, hidden_size] (hidden_size/2 per direction)
        start_logits = self.start_fc(sequence_out)       # entity-start prediction [batch_size, seq_len, num_labels]
        start_preds = paddle.cast(start_logits.argmax(-1).unsqueeze(2), dtype='float32')    # [batch_size, seq_len, 1]
        end_logits = self.end_fc(paddle.concat([sequence_out,start_preds],axis=-1))         # entity-end prediction, input shape [batch_size, seq_len, hidden_size+1]
        outputs = start_logits, end_logits         # inference outputs

        # compute the loss
        if label_start is not None and label_end is not None:
            start_loss = self.loss_fct(start_logits, label_start)
            end_loss = self.loss_fct(end_logits, label_end)
            total_loss = (start_loss + end_loss) / 2
            outputs = total_loss
        # returns the loss during training and the logits at inference
        return outputs

The idea behind adversarial training: during training, add a small perturbation on the embedding layer as a form of regularization, improving the model's generalization; see https://zhuanlan.zhihu.com/p/91269728 [6].
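FGM and PGD are imported from model.AdversarialTrain_block, which is not included in this notebook. A minimal FGM sketch in Paddle, following the reference above and matching the fgm.attack()/fgm.restore() interface used below (a reconstruction; the emb_name default is an assumption that matches ELECTRA/ERNIE parameter names):

class FGMSketch:
    """Fast Gradient Method: nudge the embedding table along the gradient,
    run an extra forward/backward pass, then restore the clean weights."""
    def __init__(self, model, epsilon=1.0, emb_name='word_embeddings'):
        self.model = model
        self.epsilon = epsilon
        self.emb_name = emb_name               # substring identifying embedding parameters
        self.backup = {}

    def attack(self):
        for name, param in self.model.named_parameters():
            if param.stop_gradient or self.emb_name not in name or param.grad is None:
                continue
            self.backup[name] = param.clone()  # save the clean weights
            norm = paddle.norm(param.grad)
            if norm.item() > 0:
                r_at = self.epsilon * param.grad / norm   # normalized perturbation
                param.set_value(param + r_at)

    def restore(self):
        for name, param in self.model.named_parameters():
            if name in self.backup:
                param.set_value(self.backup[name])
        self.backup = {}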

# FGM adversarial-training step function
def fgm_step(input_ids,token_type_ids,label_start,label_end):
    fgm.attack() # add the adversarial perturbation on the embeddings
    loss_adv = model(
        input_ids,
        token_type_ids=token_type_ids,
        label_start=label_start,
        label_end=label_end
    )
    loss_adv.backward() # backprop: accumulate the adversarial gradient on top of the normal one
    fgm.restore() # restore the embedding parameters


# PGD adversarial-training step function
def pgm_step(input_ids,token_type_ids,label_start,label_end):
    pgd.backup_grad()
    # adversarial inner loop
    K = 3
    for t in range(K):
        # add the perturbation on the embeddings; back up param.data on the first attack
        pgd.attack(is_first_attack=(t==0)) 
        if t != K-1:
            optimizer.clear_grad()
        else:
            pgd.restore_grad()
        loss_adv = model(
            input_ids,
            token_type_ids=token_type_ids,
            label_start=label_start,
            label_end=label_end
        )
        loss_adv.backward() # backprop: accumulate the adversarial gradient on top of the normal one
    pgd.restore() # restore the embedding parameters

# hyperparameters
loss_type='fcl'
num_epochs = 100
lr = 9e-5
use_AT=2                               # 0: no adversarial training, 1: FGM, 2: PGD
save_param_path='ckp/best_model/'        # where trained model parameters are stored
save_best_threshold=0.85                 # test-set F1 threshold; parameters are saved once it is exceeded
model = BertLSTMSpanNER(
    pretrained_model=pretrained_model,
    num_labels=len(label2ids),
    loss_type=loss_type
)

if use_AT == 1:
    fgm = FGM(model)
elif use_AT == 2:
    pgd = PGD(model)
# quick sanity check: run a few batches through the model and watch the loss
for step, batch in enumerate(train_dataloader, start=1):
    input_ids, token_type_ids, label_start, label_end = batch
    loss = model(
        input_ids,
        token_type_ids=token_type_ids,
        label_start=label_start,
        label_end=label_end
    )
    loss.backward()
    # pgm_step(input_ids,token_type_ids,label_start,label_end)
    print(loss)
    if step==5:
        break
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [1.90028715])
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [1.88355613])
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [1.88647342])
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [1.88523901])
Tensor(shape=[1], dtype=float32, place=Place(gpu:0), stop_gradient=False,
       [1.87226939])

4.4 Model evaluation and training

# learning-rate scheduler: linear decay with warmup over the training run
lr_scheduler = LinearDecayWithWarmup(
    learning_rate=lr,
    total_steps=50 * num_epochs,   # assumes roughly 50 optimizer steps per epoch
    warmup=0.1
)
decay_params = [
    p.name for n, p in model.named_parameters()
    if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
    learning_rate=lr_scheduler,
    parameters=model.parameters(),
    apply_decay_param_fun=lambda x: x in decay_params)
from collections import Counter

class SpanEntityMetric(object):
    def __init__(self, ids2label):
        self.ids2label = ids2label
        self.reset()

    def reset(self):
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return precision, recall, f1

    def result(self):
        class_info = {}
        origin_counter = Counter([self.ids2label[x[0]] for x in self.origins])
        found_counter = Counter([self.ids2label[x[0]] for x in self.founds])
        right_counter = Counter([self.ids2label[x[0]] for x in self.rights])
        for type_, count in origin_counter.items():
            origin = count
            found = found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            precision,recall, f1 = self.compute(origin, found, right)
            class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
        origin = len(self.origins)
        found = len(self.founds)
        right = len(self.rights)
        precision, recall,  f1 = self.compute(origin, found, right)
        return {'precision': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, true_subject, pred_subject):
        self.origins.extend(true_subject)
        self.founds.extend(pred_subject)
        self.rights.extend([pre_entity for pre_entity in pred_subject if pre_entity in true_subject])

metric = SpanEntityMetric(ids2label)
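As a quick sanity check of the metric with made-up spans (a prediction counts as correct only when label, start, and end all match):

metric.update(true_subject=[(10, 1, 2), (3, 5, 7)],
              pred_subject=[(10, 1, 2), (3, 5, 6)])
total, per_class = metric.result()
print(total)    # one of two predictions matches exactly -> precision = recall = f1 = 0.5
metric.reset()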
# extract (label_id, start, end) entity triples from predicted span labels
def extract_item(line_start, line_end):
    S = []
    for i, s_l in enumerate(line_start):
        if s_l == label2ids['O'] or s_l == -1:
            continue
        for j, e_l in enumerate(line_end[i:]):
            if s_l == e_l:
                S.append((s_l, i, i + j))
                break
    return S


# merge adjacent entities that share the same label
def combine_item(pred_labels):
    # pred_labels : [(predict_label_ids, predict_label_start, predict_label_end)*n]
    # entities within each record are sorted by ascending start position
    combined_pred = []
    flag_istaken = False
    for index in range(len(pred_labels)):
        if flag_istaken:
            flag_istaken=False
            continue
        elif index==(len(pred_labels)-1):
            combined_pred.append(pred_labels[index])
            break
        
        cur_labelid, cur_start, cur_end = pred_labels[index]
        next_labelid, next_start, next_end = pred_labels[index+1]

        if (cur_labelid == next_labelid) and (cur_end+1 == next_start):
            combined_pred.append((cur_labelid, cur_start, next_end))
            flag_istaken = True
        else:
            combined_pred.append((cur_labelid, cur_start, cur_end))

    return combined_pred
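A toy run of the two helpers together (made-up values; label id 10 stands for 临床表现):

starts = [0, 10, 0, 10, 0]           # predicted start labels per position
ends   = [0, 0, 10, 0, 10]           # predicted end labels per position
spans = extract_item(starts, ends)   # [(10, 1, 2), (10, 3, 4)]
print(combine_item(spans))           # [(10, 1, 4)] -- the two touching spans merge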
# model evaluation function
@paddle.no_grad()
def do_evaluate(data_loader):
    metric.reset() # reset the metric
    # process each batch in turn
    for input_ids, token_type_ids, label_start, label_end in data_loader:
        # the model returns the loss during training and the logits at inference
        start_logits,end_logits = model(input_ids,token_type_ids=token_type_ids)       
        label_start = label_start.cpu().numpy()[:,1:-1]
        label_end = label_end.cpu().numpy()[:,1:-1]
        start_pred = paddle.argmax(start_logits,-1).cpu().numpy()[:,1:-1]
        end_pred = paddle.argmax(end_logits,-1).cpu().numpy()[:,1:-1]
        for i in range(len(label_start)):
            label = extract_item(label_start[i], label_end[i])
            pred = extract_item(start_pred[i], end_pred[i])
            # pred = combine_item(pred) # merge adjacent same-label predictions into one entity
            metric.update(true_subject=label, pred_subject=pred)
    return metric.result()

global_step = 0

for epoch in range(1, num_epochs + 1):
    for step, batch in enumerate(train_dataloader, start=1):
        model.train()
        input_ids, token_type_ids, label_start, label_end = batch
        loss = model(
            input_ids,
            token_type_ids=token_type_ids,
            label_start=label_start,
            label_end=label_end
        )
        loss.backward()
        if use_AT == 1:    # FGM adversarial training
            fgm_step(input_ids,token_type_ids,label_start,label_end)
        elif use_AT == 2:  # PGD adversarial training
            pgm_step(input_ids,token_type_ids,label_start,label_end)

        optimizer.step()
        lr_scheduler.step()
        optimizer.clear_grad()

        global_step += 1
        # evaluate the model on the held-out set every 40 steps
        if global_step % 40 == 0:
            model.eval()
            total_metric, _ = do_evaluate(test_dataloader)
            print(
                "global_step:%d  train-epoch:%d  train_loss: %.6f" % (global_step, epoch, loss)
            )
            print('(eval):',total_metric)

            # keep the best model: save its parameters
            if total_metric['f1'] > save_best_threshold:
                save_best_threshold = total_metric['f1']
                print('saving best model...')
                save_path = os.path.join(
                    save_param_path,
                    str(save_best_threshold)+'ernie_lstm_span_model_state.pdparams'
                )
                paddle.save(model.state_dict(), save_path)

    # every 10 epochs, evaluate the model on the training set
    if epoch % 10 == 0:
        model.eval()
        total_metric,_ = do_evaluate(train_dataloader)
        print('(train)',total_metric)

4.5 Model inference


# inference function; returns {batch index : n*(pred_item_labelid, pred_start, pred_end)}
@paddle.no_grad()
def do_infer(data_loader):
    model.eval()
    preds={i:list() for i in range(len(data_loader))}      # {index : entities}
    for index, batch in enumerate(data_loader):
        if len(batch)==4:
            input_ids, token_type_ids, label_start, label_end = batch
        else:
            input_ids, token_type_ids = batch
        # the model returns the loss during training and the logits at inference
        start_logits,end_logits = model(input_ids,token_type_ids=token_type_ids) 
        start_pred = paddle.argmax(start_logits,-1).cpu().numpy()[:,1:-1]
        end_pred = paddle.argmax(end_logits,-1).cpu().numpy()[:,1:-1]

        for i in range(len(start_pred)):
            pred = extract_item(start_pred[i], end_pred[i])
            pred = combine_item(pred) # merge adjacent same-label predictions into one entity
            preds[index].append(pred)
    return preds

pretrained_path = 'ckp/pretrained_ernie-health-zh'       # ernie_health_zh
dictionary_path = 'dict/dictionary/'
infer_file_path = 'data/infer_bio/'
# trained_model_params_path = 'ckp/best_model/final/ernie_span_model_state.pdparams'
output_path = 'result/'

# # load the best saved model parameters
# best_model_param=paddle.load(trained_model_params_path)
# model.load_dict(best_model_param)
def vec2result(result):
    result_datalist = []  # entity_name, entity_label, text
    for index in result:
        text = ''.join(test_df['text'][index].split())
        for entity in result[index][0]:
            e_label, e_start, e_end = entity
            entity_name = text[e_start:e_end+1]
            entity_label = ids2label[e_label]
            result_datalist.append((entity_name, entity_label, text))
    return result_datalist
# run inference
result = do_infer(test_dataloader)
output = vec2result(result)
# save the results
output_df = pd.DataFrame(output,columns=['entity_name','entity_label', 'text'])
output_df.to_csv(os.path.join(output_path,'result.csv'),index=False,encoding='utf8')
metric_result = do_evaluate(test_dataloader)
print(metric_result)
({'precision': 0.8323500491642084, 'recall': 0.8593908629441624, 'f1': 0.8456543456543456}, {'中医诊断': {'acc': 0.7922, 'recall': 0.7176, 'f1': 0.7531}, '西医治疗': {'acc': 0.8478, 'recall': 0.9512, 'f1': 0.8966}, '西医诊断': {'acc': 0.8796, 'recall': 0.9385, 'f1': 0.9081}, '中医治则': {'acc': 0.7273, 'recall': 0.678, 'f1': 0.7018}, '中医治疗': {'acc': 0.9255, 'recall': 0.949, 'f1': 0.9371}, '临床表现': {'acc': 0.737, 'recall': 0.755, 'f1': 0.7459}, '中医证候': {'acc': 0.8909, 'recall': 0.9032, 'f1': 0.897}, '其他治疗': {'acc': 0.8235, 'recall': 0.8235, 'f1': 0.8235}, '方剂': {'acc': 0.8054, 'recall': 0.8869, 'f1': 0.8442}, '中药': {'acc': 0.8818, 'recall': 0.9122, 'f1': 0.8968}})
output_df

      entity_name      entity_label  text
0     黄疸             中医诊断      药进10帖,黄疸稍退,饮食稍增,精神稍振
1     法莫替丁         西医治疗      加味左金丸联合法莫替丁治疗胃食管反流病临床观察
2     胃食管反流病     西医诊断      加味左金丸联合法莫替丁治疗胃食管反流病临床观察
3     疏肝行气         中医治则      “疏肝行气,调神解郁”推拿法结合西药治疗腹泻型ibs的临床疗效
4     调神解郁         中医治则      “疏肝行气,调神解郁”推拿法结合西药治疗腹泻型ibs的临床疗效
...   ...              ...           ...
2029  香砂养胃丸       方剂          方法:将96例寒热错杂型慢性萎缩性胃炎患者随机分为2组,各48例,对照组服用香砂养胃丸治疗,...
2030  半夏泻心汤加味   方剂          方法:将96例寒热错杂型慢性萎缩性胃炎患者随机分为2组,各48例,对照组服用香砂养胃丸治疗,...
2031  功能性消化不良   西医诊断      免煎中药柴平舒胃汤加减治疗功能性消化不良的临床研究
2032  柴胡疏肝散       方剂          柴胡疏肝散对功能性消化不良患者胃动力及胃肠激素的影响
2033  功能性消化不良   西医诊断      柴胡疏肝散对功能性消化不良患者胃动力及胃肠激素的影响

2034 rows × 3 columns

4.6 Test-set F1 for different approaches (partial)

Approach (without adversarial training or dataset correction)      Test-set F1
BERT+LSTM+CRF (baseline)                                           0.73919
Ernie-health-ch+Bi-LSTM+CRF (BIO)                                  0.78621
Ernie-health-ch+MLP (SPAN)                                         0.80161
Nezha-wwm-large-chinese+Bi-LSTM+SPAN_predict (focal loss)          0.80034
Ernie-health-ch+Bi-LSTM+SPAN_predict (focal loss)                  0.81412

Adversarial training (FGM/PGD) adds about 1 point, and dataset correction (mainly filling in missing annotations) adds another 3-4 points.

On the first test set released by the organizers the F1 was about 0.85; on the second, final test set the weighted F1 was 0.6921, first place in the preliminary round.

5 Summary

This was a small competition around July 2022 and my first as a rookie; I mainly entered with a learning mindset and collected all kinds of references along the way. Through it I learned a lot about building and tuning models, and also discovered plenty of my own weaknesses: I can be careless when coding (more assert checks would help), and I am not yet fluent at model building (inspecting the shape of each layer's inputs and outputs helps a lot).

The model matters, and so does the data!

There is plenty of room for further optimization: decomposing characters into radicals as extra input features, further pretraining the encoder on the medical-case and journal corpora, model ensembling, K-fold training, pseudo-labeling, and so on; with limited ability and time these were not attempted.

Finally, the core code was written last June and the rest was added this year, so errors are inevitable; corrections are very welcome.

References

[1] [Paddle competition] Product review opinion extraction baseline. https://aistudio.baidu.com/aistudio/projectdetail/2417709

[2] CCKS2021 Alibaba Tianchi address recognition: BERT+BiGRU+CRF. https://aistudio.baidu.com/aistudio/projectdetail/2272540

[3] Named entity recognition with BERT (Paddle 2.0). https://aistudio.baidu.com/aistudio/projectdetail/1477098

[4] Your CRF layer's learning rate may not be large enough. https://kexue.fm/archives/7196

[5] A span-based (entity start/end pointer) sequence-labeling framework. https://github.com/wzzzd/lm_ner_span

[6] Adversarial training in NLP, with a PyTorch implementation. https://zhuanlan.zhihu.com/p/91269728

[7] Understanding focal loss and using it in multi-class tasks (PyTorch). https://blog.csdn.net/u014311125/article/details/109470137

[8] Construction and application of a clinical knowledge graph for traditional Chinese medicine. https://x.cnki.net/xmlRead/xml.html?pageType=web&fileName=KJXS201704018&tableName=CJFDTOTAL&dbCode=CJFD&topic=&fileSourceType=1&taskId=&from=&groupId=&appId=CRSP_BASIC_PSMC&act=&customReading=
