【BERT + BiLSTM + CRF】实现命名实体识别(续):封装 Dataset、DataLoader,进行批次训练

上次介绍了用【BERT + BiLSTM + CRF】实现命名实体识别的简单应用,当时只用单个例子跑通了流程;这次接着上回继续更新,封装了 Dataset,并用批量数据进行训练。本项目使用的标注好的数据集可以私信找我要哦!全程无 bug 跑完!

项目结构:

(图:项目结构)
bert-base-chinese: 存放了bert模型,vocab.txt ,config.json
data: 标注好的数据
output:输出的日志文件和模型文件
dataSet.py: 数据预处理代码
main.py: 训练和验证的代码

直接上代码,讲解放在后面。

dataSet.py

from torch.utils.data import Dataset,DataLoader
from transformers import BertTokenizer
import torch
import warnings
import os
import json
import sys
import re
warnings.filterwarnings('ignore')

def collect_data(path, original_value, result_value, a, b, c, d, e, f):
    """Read one annotated JSON file and append its fields to the given lists.

    The lists are mutated in place and also returned, so callers may either
    keep their own references or rebind to the returned tuple.

    Args:
        path: path of a UTF-8 encoded JSON file. The code looks for the
            top-level keys 'originalValue', 'resultValue' and 'classify'.
        original_value: receives ``s['originalValue']`` when that key exists.
        result_value: receives ``s['resultValue']`` when it exists and is not
            the empty string.
        a, b, c, d, e, f: receive the '组织学分型' (histological type),
            '癌结节' (tumour deposits), '两侧切缘是否有癌浸润' (margin
            invasion), 'pCRM', '脉管' (vascular) and '神经' (nerve) entries of
            ``s['classify']`` respectively, or " " when an entry is missing.

    Returns:
        ``(original_value, result_value, a, b, c, d, e, f)``.

    Errors raised while reading the fields are printed and swallowed (the
    file may then be only partially collected). Errors while opening or
    decoding the file propagate to the caller.
    """
    # classify sub-keys, in the same order as the output lists a..f
    classify_keys = ('组织学分型', '癌结节', '两侧切缘是否有癌浸润', 'pCRM', '脉管', '神经')
    with open(path, 'r', encoding='utf-8') as file:
        s = json.load(file)
    # Bind k before the loop so the error message below can always format it,
    # even if iterating s itself fails (e.g. the JSON root is a number).
    k = None
    try:
        for k in s:
            if k == 'originalValue':
                original_value.append(s['originalValue'])
            if k == 'resultValue' and s['resultValue'] != '':
                result_value.append(s['resultValue'])
            if k == 'classify':
                classify_data = s[k]
                for target, key in zip((a, b, c, d, e, f), classify_keys):
                    target.append(classify_data[key] if key in classify_data else " ")
    except Exception:
        print(f'Errors occus at path : {path}, key : "{k}", with reasons : {sys.exc_info()}')
    return original_value, result_value, a, b, c, d, e, f

def fun4Word(data):
    """Convert annotated sentences into ``'<text> //<labels>'`` lines.

    Each item is parsed with ``label_process`` (which holds the single
    implementation of the "[text]/tag" annotation parsing; previously this
    function carried a verbatim copy of it).

    Args:
        data: iterable of annotated strings.

    Returns:
        A string with one line per item whose plain text is non-empty, each
        formatted as ``word + ' //' + label + '\\n'``.
    """
    lines = []
    for sentence in data:
        word, label = label_process(sentence)
        if word != '':
            lines.append(word + ' //' + label + '\n')
    return ''.join(lines)

# Annotation tag -> label suffix used in the B_/M_/E_/W_ label set.
_TAG_TO_LABEL = {
    'aj_lcjl': 'ajl',
    'aj_hzjl': 'ajh',
    'lbj_z': 'lbjz',
    'lbj_y': 'lbjy',
    'lbj_fz': 'lbjfz',
    'mlh1': 'mlh1',
    'msh2': 'msh2',
    'msh6': 'msh6',
    'pms2': 'pms2',
    'ki67': 'ki67',
    'p53': 'p53',
}
# Same alternatives (in the same order) as the original split pattern, but the
# bracketed text and the tag name are captured separately.
_ENTITY_RE = re.compile(
    r'\[([^\]]+)\]/(aj_lcjl|aj_hzjl|lbj_z|lbj_y|lbj_fz|mlh1|msh2|msh6|pms2|ki67|p53)')
# Tags whose text may end in a size unit ('cm' or 'c') that is labelled 'O'.
_SIZE_TAGS = frozenset({'aj_lcjl', 'aj_hzjl'})


def _span_labels(length, name):
    """Return space-terminated labels for an entity of *length* characters."""
    if length <= 0:
        return ''
    if length == 1:
        return f'W_{name} '
    return f'B_{name} ' + (length - 2) * f'M_{name} ' + f'E_{name} '


def _entity_labels(text, tag):
    """Return one label per character of the annotated entity *text*."""
    name = _TAG_TO_LABEL[tag]
    if tag in _SIZE_TAGS:
        # The unit is excluded from the entity span and labelled 'O', so the
        # total label count always equals len(text).
        if 'cm' in text:
            unit = 2
        elif 'c' in text:
            unit = 1
        else:
            unit = 0
        return _span_labels(len(text) - unit, name) + 'O ' * unit
    return _span_labels(len(text), name)


def label_process(data):
    """Strip "[text]/tag" annotations from *data* and build per-char labels.

    Annotated spans become B_x/M_x/E_x labels (W_x for one character); all
    other characters are labelled 'O'. Segments are classified by whether
    they actually matched the annotation pattern, so plain text that merely
    contains a tag name such as 'p53' stays plain text.

    Args:
        data: annotated sentence.

    Returns:
        ``(word, label)``: the plain text and a string of space-terminated
        labels with exactly one label per character of ``word``.
    """
    words, labels = [], []
    pos = 0
    for match in _ENTITY_RE.finditer(data):
        plain = data[pos:match.start()]
        words.append(plain)
        labels.append('O ' * len(plain))
        text, tag = match.group(1), match.group(2)
        words.append(text)
        labels.append(_entity_labels(text, tag))
        pos = match.end()
    tail = data[pos:]
    words.append(tail)
    labels.append('O ' * len(tail))
    return ''.join(words), ''.join(labels)

def my_collate(data):
    """Collate ``(input, label)`` pairs into two batched tensors.

    Each ``input`` may be a plain sequence of token ids, a tensor, or a
    mapping with an ``'input_ids'`` entry (e.g. the ``encode_plus`` result
    that ``MyDataSet`` stores); for mappings only ``input_ids`` is kept.
    A leading dimension of size 1 on a tensor input is dropped, so samples
    shaped (1, seq_len) batch to (batch, seq_len).

    Args:
        data: list of ``(input, label)`` pairs.

    Returns:
        ``(inputs, labels)`` as tensors built with ``torch.tensor``.
    """
    from collections.abc import Mapping

    inputs, labels = [], []
    for sample, label in data:
        if isinstance(sample, Mapping) and 'input_ids' in sample:
            sample = sample['input_ids']
        if isinstance(sample, torch.Tensor):
            # torch.tensor() cannot build a batch from a list of multi-element
            # tensors, so convert to nested Python lists first.
            if sample.dim() > 1 and sample.shape[0] == 1:
                sample = sample[0]
            sample = sample.tolist()
        inputs.append(sample)
        labels.append(label)
    return torch.tensor(inputs), torch.tensor(labels)

class MyDataSet(Dataset):
    """Character-level NER dataset built from annotated JSON files.

    The constructor walks ``root/<sub_dir>/<file>`` (two directory levels),
    passes every file to ``collect_data`` and turns each collected
    ``resultValue`` string into one sample via ``label_process``.

    A sample is ``(encoding, encoded_label)``:

    * ``encoding`` -- the ``tokenizer.encode_plus`` result, with
      ``input_ids``, ``attention_mask`` and ``token_type_ids`` tensors of
      shape (1, max_length);
    * ``encoded_label`` -- a list of ``max_length`` label ids. Position 0
      belongs to ``[CLS]`` and is labelled 'O', so label ``i`` is paired with
      ``input_ids[0, i]``; padding positions are 'O' as well.

    Args:
        max_length: fixed sequence length, including [CLS] and [SEP].
        root: directory containing the annotated data.
        pretrained: name or path given to ``BertTokenizer.from_pretrained``.

    Attributes:
        labels: label vocabulary; a label's id is its index in this list.
        tag_num: ``len(labels)``.
        data: list of tokenizer encodings.
        label: list of encoded label lists (same order as ``data``).
    """

    def __init__(self, max_length=512, root='data', pretrained='bert-base-chinese'):
        # parameters
        labels = ['B_lbjy', 'M_lbjy', 'E_lbjy', 'W_lbjy', 'B_lbjz', 'M_lbjz', 'E_lbjz', 'W_lbjz', 'B_lbjfz', 'M_lbjfz',
                  'E_lbjfz', 'W_lbjfz',
                  'B_ajl', 'M_ajl', 'E_ajl', 'W_ajl', 'B_ajh', 'M_ajh', 'E_ajh', 'W_ajh', 'B_mlh1', 'M_mlh1', 'E_mlh1',
                  'W_mlh1', 'B_msh2', 'M_msh2', 'E_msh2', 'W_msh2',
                  'B_msh6', 'M_msh6', 'E_msh6', 'W_msh6', 'B_pms2', 'M_pms2', 'E_pms2', 'W_pms2', 'B_ki67', 'M_ki67',
                  'E_ki67', 'W_ki67', 'B_p53', 'M_p53', 'E_p53', 'W_p53', 'O']
        self.labels = labels
        self.tag_num = len(labels)
        # label -> id lookup, built once instead of once per sentence
        label2id = {name: idx for idx, name in enumerate(labels)}
        tokenizer = BertTokenizer.from_pretrained(pretrained)

        # ---- collect data ----
        original_value, result_value, a, b, c, d, e, f = [], [], [], [], [], [], [], []
        count = 0
        # sorted() makes the sample order reproducible (os.listdir order is arbitrary)
        for path in sorted(os.listdir(root)):
            father_path = os.path.join(root, path)
            for child_path in sorted(os.listdir(father_path)):
                count += 1
                file_path = os.path.join(father_path, child_path)
                n_before = len(result_value)
                original_value, result_value, a, b, c, d, e, f = collect_data(
                    file_path, original_value, result_value, a, b, c, d, e, f)
                # collect_data only appends a non-empty resultValue, so no growth
                # means this file contributed no NER sample.
                if len(result_value) == n_before:
                    print(f'result_value null: count :{count}, path:{file_path}')
        print(f'Data Collection Info : original_value : {len(original_value)} result_value : {len(result_value)} '
              f'a : {len(a)} b : {len(b)} c :{len(c)} d : {len(d)} e : {len(e)} f : {len(f)} final count : {count}')

        # ---- ner data process: tokenize text and encode labels ----
        # [CLS] and [SEP] occupy two of the max_length positions.
        max_chars = max(max_length - 2, 0)
        tokenized_data = []
        encoded_labels = []
        for sentence in result_value:
            word, label = label_process(sentence)
            # Truncate over-long text.
            # TODO: split long passages into several max_length samples instead of dropping the tail.
            word = word[:max_chars]
            # truncation=True guarantees exactly max_length tokens; without it a
            # max_chars+ text plus [CLS]/[SEP] overflows and batches become ragged.
            # NOTE(review): BERT's WordPiece may merge runs of digits/latin letters
            # into one token, so char-level labels can still drift after such runs.
            s = tokenizer.encode_plus(word, return_token_type_ids=True, return_attention_mask=True,
                                      return_tensors='pt', padding='max_length', truncation=True,
                                      max_length=max_length)
            tokenized_data.append(s)
            # Prepend 'O' for [CLS] so label i lines up with token i, then pad with 'O'.
            # split() (not split(' ')) yields [] for an empty label string.
            tags = (['O'] + label.split()[:max_chars])[:max_length]
            tags += ['O'] * (max_length - len(tags))
            encoded_label = [label2id[k] for k in tags]
            encoded_labels.append(encoded_label)
            if (s.input_ids.shape[1] != max_length or s.attention_mask.shape[1] != max_length
                    or s.token_type_ids.shape[1] != max_length):
                print(f'len data:{s.input_ids.shape} {s.attention_mask.shape} {s.token_type_ids.shape} len label:{len(encoded_label)}')
        self.data = tokenized_data
        self.label = encoded_labels
        # TODO add classification data process (a..f are collected but not used yet)

    def __getitem__(self, index):
        """Return ``(encoding, encoded_label)`` for sample *index*."""
        return self.data[index], self.label[index]

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

###############################
if __name__ == '__main__':
    dataset = MyDataSet()
    loader = DataLoader(dataset=dataset, shuffle=False, batch_size=10, collate_fn=my_collate)
    # Smoke test: print each batch's shapes and count the batches yielded.
    token_count = 0
    for inputs, labels in loader:
        print(f'inputs_size:{inputs.shape}\t labels_size:{labels.shape}')
        token_count += 1
    print(f'token_count:{token_count}')

# output = fun4Word(result_value)
# with open('output.txt','w',encoding='utf-8') as file:
#     file.write(output)
#     file.close()

main.py

# -*- encoding : utf-8 -*-
'''
@author : sito
@date : 2022-02-25
@description:
Build a BERT + BiLSTM + CRF model for named entity recognition (NER),
using a small amount of code on top of transformers, torch and pytorch-crf.
'''
import logging
import os
import sys
import time
import warnings

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchcrf import CRF
from transformers import BertModel

from dataSet import MyDataSet
warnings.filterwarnings('ignore')
# log configuration: the 'training log' logger writes INFO records both to
# stderr and to output/training.log
logger = logging.getLogger('training log')
logger.setLevel(logging.INFO)
# stream handler
rf_handler = logging.StreamHandler(sys.stderr)
rf_handler.setLevel(logging.INFO)
rf_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(message)s"))
# file handler; FileHandler opens its file immediately and does not create
# missing directories, so make sure output/ exists first
os.makedirs('output', exist_ok=True)
f_handler = logging.FileHandler('output/training.log', encoding='utf-8')
f_handler.setLevel(logging.INFO)
f_handler.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(filename)s[:%(lineno)d] - %(message)s"))
logger.addHandler(rf_handler)
logger.addHandler(f_handler)

def my_collate(data):
    """Batch ``(encoding, label)`` samples into model-ready tensors.

    Args:
        data: list of ``(encoding, label)`` pairs. ``encoding`` must expose
            ``input_ids``, ``attention_mask`` and ``token_type_ids`` tensors,
            assumed to be shaped (1, seq_len) or (seq_len,) -- TODO confirm
            for tokenizers other than the one MyDataSet uses. ``label`` is a
            sequence of label ids.

    Returns:
        ``(inputs, labels)``: ``inputs`` maps the three field names to
        (batch, seq_len) tensors and ``labels`` is a (batch, seq_len) tensor;
        everything is moved to CUDA when available, otherwise to CPU.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    encodings = [sample for sample, _ in data]

    def _batch(field):
        # reshape(1, -1) gives each sample exactly one row, so even a
        # length-1 sequence keeps its seq dimension (squeeze() would not).
        rows = [getattr(enc, field).detach().cpu().reshape(1, -1) for enc in encodings]
        return torch.cat(rows, dim=0).to(device)

    inputs = {field: _batch(field) for field in ('input_ids', 'attention_mask', 'token_type_ids')}
    labels = torch.tensor([label for _, label in data]).to(device)
    return inputs, labels

class Model(nn.Module):
    """BERT encoder -> 2-layer BiLSTM -> linear emission layer -> CRF.

    BERT is run under ``torch.no_grad()``, so only the LSTM, linear and CRF
    layers receive gradients (BERT acts as a frozen feature extractor).

    Args:
        tag_num: number of output labels.
    """

    def __init__(self, tag_num):
        super().__init__()
        self.bert = BertModel.from_pretrained('bert-base-chinese')
        config = self.bert.config
        self.lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=config.hidden_size,
                            hidden_size=config.hidden_size // 2, batch_first=True)
        # The emissions fed to the CRF are (batch, seq_len, tag_num); CRF's
        # default batch_first=False would treat the batch axis as time.
        self.crf = CRF(tag_num, batch_first=True)
        self.fc = nn.Linear(config.hidden_size, tag_num)

    def forward(self, x, y):
        """Compute the CRF loss and the best tag paths.

        Args:
            x: dict with 'input_ids', 'attention_mask', 'token_type_ids'
                tensors of shape (batch, seq_len).
            y: gold label ids of shape (batch, seq_len).

        Returns:
            ``(loss, tag)``: ``loss`` is the negative CRF log-likelihood summed
            over the batch (non-negative, so it can be minimised directly).
            ``tag`` is the Viterbi path in seq-major layout -- a list of
            seq_len lists, each holding one tag id per batch element -- kept so
            existing callers that do ``np.array(tag).T`` still get
            (batch, seq_len).
        """
        with torch.no_grad():
            bert_output = self.bert(input_ids=x['input_ids'], attention_mask=x['attention_mask'],
                                    token_type_ids=x['token_type_ids'])[0]
        lstm_output, _ = self.lstm(bert_output)  # (batch_size, seq_len, hidden_size)
        fc_output = self.fc(lstm_output)  # (batch_size, seq_len, tag_num)
        # CRF.forward returns the log-likelihood (<= 0); negate it to get a loss.
        loss = -self.crf(fc_output, y)
        best_paths = self.crf.decode(fc_output)  # batch_size lists of seq_len tags
        tag = [list(step) for step in zip(*best_paths)]
        return loss, tag

if __name__ == '__main__':
    # parameters
    epoches = 50
    max_length = 512
    batch_size = 64  # 32
    lr = 0.0001  # 5e-4  0.0001
    # same device rule as my_collate, so the model and the batches end up together
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # data preprocess
    dataset = MyDataSet(max_length)
    tag_num = dataset.tag_num
    data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=batch_size, collate_fn=my_collate)
    # training
    logger.info(f'>>> Training Start!')
    model = Model(tag_num).to(device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    # T_max counts scheduler.step() calls; stepping once per epoch gives a
    # single cosine decay over the whole run (stepping per batch with T_max=50
    # made the learning rate oscillate every 50 batches).
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epoches)
    for e in range(epoches):
        # training
        model.train()
        loss_sum, n_batches = 0.0, 0
        for step, (inputs, labels) in enumerate(data_loader):
            optimizer.zero_grad()
            loss, _ = model(inputs, labels)
            # abs() turns a CRF log-likelihood (<= 0) into a loss; it is a no-op
            # when the model already returns the negative log-likelihood.
            loss = abs(loss)
            loss.backward()
            optimizer.step()
            # .item() stores a float instead of keeping the tensor (and its graph) alive
            loss_sum += loss.item()
            n_batches += 1
            if step % 10 == 0:
                logger.info(f'>>> epoch {e} <<< step {step} : loss : {loss.item()}')
        scheduler.step()
        epoch_end_loss = loss_sum / max(n_batches, 1)
        logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} training loss : {epoch_end_loss}')
        # evaluating
        if e % 10 == 0 and e != 0:
            model.eval()
            logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} Start Evaluation!')
            # NOTE(review): this reuses the training loader, so the number below is
            # training-set accuracy; hold out a validation split for a real estimate.
            correct, total = 0, 0
            with torch.no_grad():
                for inputs, labels in data_loader:
                    _, tag = model(inputs, labels)
                    pred = np.array(tag).T  # (batch_size, seq_len)
                    real = labels.cpu().numpy()
                    if pred.shape != real.shape:
                        raise ValueError(f'shape mismatch: pred {pred.shape} real {real.shape} max_length {max_length}')
                    # Score only real token positions: padding is always 'O' and
                    # would otherwise dominate (and inflate) the accuracy.
                    mask = inputs['attention_mask'].cpu().numpy().astype(bool)
                    correct += int(((pred == real) & mask).sum())
                    total += int(mask.sum())
            epoch_end_accuracy = correct / total if total else 0.0
            logger.info(f'{time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())} epoch {e} evaluation accuracy : {epoch_end_accuracy}')
            # save model next to the training log
            torch.save(model.state_dict(), os.path.join('output', f'model_p_{epoch_end_accuracy}.pt'))

标注好的数据集这里就不方便贴出来了,感兴趣的小伙伴可以私信找我要。

训练的过程也比较简单:一共 50 个 epoch,每 10 个 epoch 进行一次 evaluation 并保存模型;优化器使用了 Adam,学习率方面使用了余弦退火的策略。因为一开始学习率需要大一点,越到后面模型学到的信息越多,就不需要很大的学习率了,小的学习率反而能增加模型的鲁棒性和性能,具体可以看一下 ALBERT 的论文,里面有很多训练 trick。

欢迎大家一起学习交流,博主对计算机视觉和NLP方向都很感兴趣,以后也会不定时的更新一些好的比较有用的文章,感兴趣的童鞋可以关注我哦~

最后贴一下训练输出的结果吧~
(图:训练输出结果)
可以看到在第10个epoch的时候,已经有0.96051的准确率了,bert果然是厉害啊!需要注意的是,这里的准确率是直接在训练数据上评估得到的,并不能代表模型在未见过的数据上的效果。

  • 5
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 打赏
    打赏
  • 13
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 13
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Sito_zz

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值