通过本文,你将了解如何基于训练好的模型,来编写一个 REST 风格的命名实体提取接口:传入一个句子,接口会提取出句子中的人名、地址、组织、公司、产品、时间信息并返回。
核心模块entity_extractor.py
关键函数
# 加载实体识别模型
def person_model_init():
...
# 预测句子中的实体
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
pred_ids,
tokenizer,
sess, max_seq_length):
...
完整代码
# -*- coding: utf-8 -*-
"""
基于模型的地址提取
"""
__author__ = '程序员一一涤生'
import codecs
import os
import pickle
from datetime import datetime
from pprint import pprint
import numpy as np
import tensorflow as tf
from bert_base.bert import tokenization, modeling
from bert_base.train.models import create_model, InputFeatures
from bert_base.train.train_helper import get_args_parser
args = get_args_parser()
def convert(line, model_dir, label_list, tokenizer, batch_size, max_seq_length):
    """Turn one tokenized sentence into model-ready id matrices.

    Builds a single InputFeatures example (mode 'p' = predict) and reshapes
    each of its id fields into a (batch_size, max_seq_length) numpy array.

    Returns:
        Tuple of (input_ids, input_mask, segment_ids, label_ids) arrays.
    """
    feature = convert_single_example(model_dir, 0, line, label_list, max_seq_length, tokenizer, 'p')
    shape = (batch_size, max_seq_length)
    fields = (feature.input_ids, feature.input_mask,
              feature.segment_ids, feature.label_ids)
    return tuple(np.reshape([field], shape) for field in fields)
def predict(sentence, labels_config, model_dir, batch_size, id2label, label_list, graph, input_ids_p, input_mask_p,
            pred_ids,
            tokenizer,
            sess, max_seq_length):
    """Run NER inference on one raw sentence.

    Tokenizes the sentence, converts it to feature arrays, evaluates the
    `pred_ids` tensor in `sess` under `graph`, decodes the predicted label
    ids, and combines token/tag pairs into an entity dict.

    Returns:
        (result, pred_label_result) where `result` is the entity dict from
        strage_combined and `pred_label_result` is the decoded tag sequence
        per batch row.
    """
    with graph.as_default():
        start = datetime.now()
        tokens = tokenizer.tokenize(sentence)
        input_ids, input_mask, _segment_ids, _label_ids = convert(
            tokens, model_dir, label_list, tokenizer, batch_size, max_seq_length)
        # Only ids and mask are fed; segment ids are not consumed by the graph.
        feed_dict = {input_ids_p: input_ids,
                     input_mask_p: input_mask}
        pred_ids_result = sess.run([pred_ids], feed_dict)
        pred_label_result = convert_id_to_label(pred_ids_result, id2label, batch_size)
        print(pred_label_result)
        # Combine per-token tags into whole entities.
        result = strage_combined(tokens, pred_label_result[0], labels_config)
        print('time used: {} sec'.format((datetime.now() - start).total_seconds()))
        return result, pred_label_result
def convert_id_to_label(pred_ids_result, idx2label, batch_size):
    """Map predicted label ids back to their string tags.

    For each of the `batch_size` rows, ids are translated through
    `idx2label` until the first 0 id (padding) is hit; the special
    '[CLS]' and '[SEP]' markers are dropped from the output.

    :param pred_ids_result: nested prediction ids; row r is read via
        pred_ids_result[r][0] (the sess.run wrapping adds one level).
    :param idx2label: mapping from label id to label string.
    :param batch_size: number of rows to decode.
    :return: list of tag-string lists, one per row.
    """
    decoded = []
    for row in range(batch_size):
        tags = []
        for label_id in pred_ids_result[row][0]:
            if label_id == 0:
                break
            tag = idx2label[label_id]
            if tag not in ('[CLS]', '[SEP]'):
                tags.append(tag)
        decoded.append(tags)
    return decoded
def strage_combined(tokens, tags, labels_config):
    """Combine per-token tags into an entity dict keyed by entity type.

    Delegates span extraction to Result.get_result, then flattens each
    extracted item to its word text with '#' markers (WordPiece pieces —
    presumably; verify against the tokenizer) removed.

    :param tokens: tokenized sentence; truncated to len(tags) if longer.
    :param tags: decoded tag sequence for one batch row.
    :param labels_config: label configuration passed through to Result.
    :return: dict mapping entity type to a list of extracted word strings.
    """
    def _words_of(items):
        return [str(item.word).replace("#", "") for item in items]

    extractor = Result(labels_config)
    if len(tokens) > len(tags):
        tokens = tokens[:len(tags)]
    labels_dict = extractor.get_result(tokens, tags)
    return {entity_type: _words_of(items)
            for entity_type, items in labels_dict.items()}
def convert_single_example(model_dir, ex_index, example, label_list, max_seq_length, tokenizer, mode):
"""
将一个样本进行分析,然后将字转化为id, 标签转化为id,然后结构化到InputFeatures对象中
:param ex_index: index
:param example: 一个样本
:param label_list: 标签列表
:param max_seq_length:
:param tokenizer:
:param mode:
:return:
"""
label_map = {}
# 1表示从1开始对label进行index化
for (i, label) in enumerate(label_list, 1):
label_map[label] = i
# 保存label->index 的map
if not os.path.exists(os.path.join(model_dir, 'label2id.pkl')):
with codecs.open(os.path.join(model_dir, 'label2id.pkl'), 'wb') as w:
pickle.dump(label_map, w)
tokens = example
# tokens = tokenizer.tokenize(example.text)
# 序列截断
if len(tokens) >= max_seq_length - 1:
tokens = tokens[0:(max_seq_length - 2)] # -2 的原因是因为序列需要加一个句首和句尾标志
ntokens =