Named Entity Recognition with Deep Learning (Part 5): Using the Model

This article shows how to build a RESTful API on top of a trained model using Python's Flask framework, so that the API can extract entities such as person names, addresses, and organizations from input text. It covers the project structure, the startup and configuration files, and how to call and performance-tune the API.

In this article you will learn how to build a REST-style named-entity extraction API on top of a trained model: you pass in a sentence, and the API extracts and returns the person names, addresses, organizations, companies, products, and time expressions it contains.
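To make the goal concrete, here is a minimal sketch of what such a Flask endpoint could look like. The /ner route, the text request field, and the extract_entities helper are illustrative placeholders, not the project's actual code:

from flask import Flask, jsonify, request

app = Flask(__name__)

# hypothetical endpoint: POST a JSON body like {"text": "..."} to /ner
@app.route('/ner', methods=['POST'])
def ner():
    sentence = request.json['text']
    # extract_entities stands in for the predict() pipeline shown below
    entities = extract_entities(sentence)
    return jsonify(entities)

if __name__ == '__main__':
    app.run(host='0.0.0.0', port=5000)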

Core module: entity_extractor.py
Key functions
# Load the entity recognition model
def person_model_init():
    ...

# Predict the entities in a sentence
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    ...
Full code
# -*- coding: utf-8 -*-

"""
基于模型的地址提取
"""
__author__ = '程序员一一涤生'

import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint
import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser

args = get_args_parser()

def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
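    # Build the model inputs for a single tokenized sentence, then reshape
    # each field to (batch_size, max_seq_length); note that reshaping one
    # example this way only works when batch_size == 1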
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    input_ids = np.reshape([feature.input_ids], (batch_size, max_seq_length))
    input_mask = np.reshape([feature.input_mask], (batch_size, max_seq_length))
    segment_ids = np.reshape([feature.segment_ids], (batch_size, max_seq_length))
    label_ids = np.reshape([feature.label_ids], (batch_size, max_seq_length))
    return input_ids, input_mask, segment_ids, label_ids

def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
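    # Pipeline: tokenize the sentence, build the feed_dict, run the graph to
    # decode the label ids, map the ids back to labels, then merge the tags
    # into complete entities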
    with graph.as_default():
        start = datetime.now()
        # print(id2label)
        sentence = tokenizer.tokenize(sentence)
        # print('your input is:{}'.format(sentence))
        input_ids, input_mask, segment_ids, label_ids = convert(sentence, model_dir, label_list, tokenizer, batch_size,
                                                                max_seq_length)

        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        # run the session to get the prediction result for the current feed_dict
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        # print(pred_ids_result)
        print(pred_label_result)
        # todo: combination strategy
        result = strage_combined(sentence, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
    return result, pred_label_result

def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """
    将id形式的结果转化为真实序列结果
    :param pred_ids_result:
    :param idx2label:
    :return:
    """
    result = []
    for row in range(batch_size):
        curr_seq = []
        for ids in pred_ids_result[row][0]:
            if ids == 0:
                break
            curr_label = idx2label[ids]
            if curr_label in ['[CLS]', '[SEP]']:
                continue
            curr_seq.append(curr_label)
        result.append(curr_seq)
    return result

def strage_combined(tokens, tags, labels_config):
    """
    组合策略
    :param pred_label_result:
    :param types:
    :return:
    """
    def get_output(rs, data, type):
        words = []
        for i in data:
            # strip the '#' markers that WordPiece adds to subword tokens
            words.append(str(i.word).replace("#", ""))
            # words.append(i.word)
        rs[type] = words
        return rs
    res = Result(labels_config)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = res.get_result(tokens, tags)
    arr = []
    for k, v in labels_dict.items():
        arr.append((k, v))
    rs = {}
    for item in arr:
        rs = get_output(rs, item[1], item[0])
    return rs

def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
    """
    将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中
    :param ex_index: index
    :param example: 一个样本
    :param label_list: 标签列表
    :param max_seq_length:
    :param tokenizer:
    :param mode:
    :return:
    """
    label_map = {}
    # enumerate(..., 1): label indexing starts from 1 (0 is reserved for padding)
    for (i, label) in enumerate(label_list, 1):
        label_map[label] = i
    # persist the label -> index map
    if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
        with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
            pickle.dump(label_map, w)
    tokens = example
    # tokens = tokenizer.tokenize(example.text)
    # truncate the sequence
    if len(tokens) >= max_seq_length - 1:
        tokens = tokens[0:(max_seq_length - 2)]  # -2 because the sequence needs a sentence-start and a sentence-end marker
    # the original listing breaks off here; the completion below follows the
    # standard bert_base predict-mode implementation of this function
    ntokens = []
    segment_ids = []
    label_ids = []
    ntokens.append("[CLS]")  # sentence-start marker
    segment_ids.append(0)
    label_ids.append(label_map["[CLS]"])
    for i, token in enumerate(tokens):
        ntokens.append(token)
        segment_ids.append(0)
        label_ids.append(0)  # labels are unknown at predict time
    ntokens.append("[SEP]")  # sentence-end marker
    segment_ids.append(0)
    label_ids.append(label_map["[SEP]"])
    input_ids = tokenizer.convert_tokens_to_ids(ntokens)
    input_mask = [1] * len(input_ids)
    # zero-pad everything up to max_seq_length
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        input_mask.append(0)
        segment_ids.append(0)
        label_ids.append(0)
        ntokens.append("**NULL**")
    feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_ids=label_ids,
    )
    return feature
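With the module in place, the pieces can be wired together roughly as follows. This is a hedged sketch: person_model_init's body is elided above, so the assumption that it returns the session, graph, placeholders, tokenizer, and label mappings that predict expects, and the concrete values for model_dir, labels_config, and max_seq_length, are all illustrative:

# Hypothetical wiring; assumes person_model_init() returns everything predict() needs
(sess, graph, input_ids_p, input_mask_p, pred_ids,
 tokenizer, id2label, label_list) = person_model_init()

labels_config = ["person", "address", "organization", "company", "product", "time"]  # assumed label set
result, pred_labels = predict("张三在上海的一一科技公司上班。", labels_config,
                              model_dir="./output", batch_size=1,
                              id2label=id2label, label_list=label_list,
                              graph=graph, input_ids_p=input_ids_p,
                              input_mask_p=input_mask_p, pred_ids=pred_ids,
                              tokenizer=tokenizer, sess=sess, max_seq_length=128)
print(result)  # e.g. {"person": ["张三"], "address": ["上海"], ...}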