1. 使用BERT做文本分类时,如果用tf.estimator.Estimator做预测,速度总是很慢,光是一个生成器的解析就很耗时;而且无论使用ckpt还是pb模型预测,最终都会变成对ckpt模型的预测。因此,本文重写了基于pb模型的预测代码:
BERT模型训练完以后得到的文件有:
其中,pb文件夹中的classification_model.pb就是预测时使用的pb格式模型;
2.首先读取pb模型:
# Read the serialized GraphDef from the .pb file.
with tf.gfile.GFile(self.model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
# Import it into a fresh graph. name='' adds no name-scope prefix, which is
# why the tensors are later fetched by their bare names (e.g. 'input_ids:0').
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')
# graph: this object can be returned to the caller
遍历graph中的所有操作并打印其名字后,可以找到模型输入、输出节点的名字(部分输出如下):
input_ids
input_mask
Shape
strided_slice/stack
...
...
...
loss/BiasAdd
loss/Softmax
pred_prob
# 中间节点省略
可以看出输入有两个:input_ids、input_mask;
输出为:pred_prob
3. 使用graph.get_tensor_by_name()方法按名字获取已保存的输入占位符(placeholder)和输出张量(名字后的":0"表示该操作的第一个输出),并创建一个session:
# Fetch the two input placeholders and the output tensor by name;
# ':0' selects the first output of the named op.
input_ids = graph.get_tensor_by_name('input_ids:0')
input_mask = graph.get_tensor_by_name('input_mask:0')
pred_prob = graph.get_tensor_by_name('pred_prob:0')
# One session bound to the imported graph; create it once and reuse it for predictions.
sess = tf.Session(graph=graph)
4. 数据的预处理与训练时大同小异,都是把一条文本(text)转换成features;
features包含:
input_ids
input_mask
segment_ids
label_id
is_real_example
5. 接下来使用前面得到的session和features做预测,代码很简单:
# Feed a batch of one example (each feature list wrapped in an outer list);
# only input_ids and input_mask are fed to the graph.
preds_evaluated = sess.run([pred_prob], feed_dict={input_ids: [features.input_ids], input_mask: [features.input_mask]})
# sess.run returns one result per fetched tensor; take the pred_prob result.
pred = preds_evaluated[0]
# Row 0 is the single example; argmax over axis 1 gives the best class id.
pred_index = pred.argmax(axis=1)[0]
pred_score = pred[0][pred_index]
得到的pred_index就是类别id,pred_score就是该类别在pred_prob中对应的分数;
最关键的是弄清pb模型输入、输出节点的名字,知道了节点名字,就可以进行后续的一系列步骤。
6.完整测试代码
import time
import cv2
import tensorflow as tf
import tokenization
import csv
class InputExample(object):
    """One raw text example: an id, one or two text segments and a label."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        # Keep the constructor arguments verbatim as public attributes.
        self.guid, self.text_a, self.text_b, self.label = guid, text_a, text_b, label
class InputFeatures(object):
    """Model-ready features for one example.

    Holds the token ids, the input mask, the segment ids, the numeric label id
    and the ``is_real_example`` flag, all stored exactly as given.
    """

    def __init__(self,
                 input_ids, input_mask, segment_ids, label_id, is_real_example=True):
        # Store every argument unchanged on the instance.
        self.input_ids, self.input_mask, self.segment_ids = input_ids, input_mask, segment_ids
        self.label_id, self.is_real_example = label_id, is_real_example
class PredictDetect(object):
    """Single-sentence BERT classification on a frozen ``.pb`` graph.

    The graph is loaded once and one ``tf.Session`` is kept open and reused
    by every :meth:`detect` call. Call :meth:`close` (or use the instance as
    a context manager) to release the session when finished.
    """

    def __init__(self, model_path='./pb/classification_model.pb',
                 vocab_file='./pb/vocab.txt', label_list=None,
                 max_seq_length=150, do_lower_case=True, verbose=False):
        """Load the frozen graph, open a session and build the tokenizer.

        Args:
            model_path: path of the frozen classification graph (.pb).
            vocab_file: vocabulary file passed to ``tokenization.FullTokenizer``.
            label_list: class names ordered by class id; defaults to
                ``['class1', 'class2', 'class3']``.
            max_seq_length: length every input is truncated/padded to.
                Presumably this must match the length the model was exported
                with -- TODO confirm against the export script.
            do_lower_case: passed to the tokenizer.
            verbose: when True, print every op name on load and the predicted
                class id on each ``detect`` call (debug output).
        """
        self.label_list = (list(label_list) if label_list is not None
                           else ['class1', 'class2', 'class3'])
        self.model_path = model_path
        self.max_seq_length = max_seq_length
        self.verbose = verbose
        self.detection_graph = self._load_model()
        self.sess = tf.Session(graph=self.detection_graph)
        # ':0' selects the first output tensor of each named op.
        self.input_ids = self.detection_graph.get_tensor_by_name('input_ids:0')
        self.input_mask = self.detection_graph.get_tensor_by_name('input_mask:0')
        self.pred_prob = self.detection_graph.get_tensor_by_name('pred_prob:0')
        self.input_map = {"input_ids": self.input_ids, "input_mask": self.input_mask}
        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file,
                                                    do_lower_case=do_lower_case)

    def _load_model(self):
        """Read the frozen GraphDef at ``self.model_path`` into a new ``tf.Graph``."""
        with tf.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())
        with tf.Graph().as_default() as graph:
            # name='' adds no name-scope prefix, so tensors can be fetched by
            # their bare names ('input_ids:0', ...).
            tf.import_graph_def(graph_def, name='')
        # Listing op names is how the input/output node names are discovered;
        # the graph has many ops, so only print them on request.
        if getattr(self, 'verbose', False):
            for op in graph.get_operations():
                print(op.name)
        return graph

    def _data_init(self, sentence):
        """Turn a raw sentence into an ``InputFeatures`` ready to feed the graph."""
        _exam = self.one_example(sentence)
        return self.convert_single_example(0, _exam, self.label_list,
                                           self.max_seq_length, self.tokenizer)

    def detect(self, sentence):
        """Classify ``sentence``.

        Returns:
            A ``(label, score)`` pair: the predicted class name taken from
            ``self.label_list`` and the ``pred_prob`` value for that class.
        """
        features = self._data_init(sentence)
        # Batch of one: wrap each feature list in an outer list. Only
        # input_ids and input_mask are fed to the graph.
        pred = self.sess.run(self.pred_prob, feed_dict={
            self.input_ids: [features.input_ids],
            self.input_mask: [features.input_mask],
        })
        pred_index = int(pred.argmax(axis=1)[0])
        pred_score = pred[0][pred_index]
        if self.verbose:
            print(pred_index)
        return self.label_list[pred_index], pred_score

    def close(self):
        """Close the underlying ``tf.Session``; ``detect`` cannot be used afterwards."""
        self.sess.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.close()
        return False

    @staticmethod
    def one_example(sentence):
        """Wrap ``sentence`` in an ``InputExample`` with a placeholder label ``'#'``."""
        guid, label = 'pred-0', '#'
        text_a, text_b = sentence, None
        return InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label)

    @staticmethod
    def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Truncate a token-list pair in place until its total length <= max_length.

        Always trims the currently longer list (``tokens_b`` on ties), one
        token at a time from the end.
        """
        while True:
            total_length = len(tokens_a) + len(tokens_b)
            if total_length <= max_length:
                break
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def convert_single_example(self, ex_index, example, label_list_, max_seq_length, tokenizer):
        """Convert a single ``InputExample`` into a single ``InputFeatures``.

        Tokens are laid out as ``[CLS] a [SEP]`` (segment 0), optionally
        followed by ``b [SEP]`` (segment 1). They are truncated to fit
        ``max_seq_length`` and then zero-padded.

        A label that is not in ``label_list_`` maps to id 0. This includes the
        ``'#'`` placeholder produced by :meth:`one_example`; previously it
        raised ``KeyError`` on every prediction. ``detect`` never feeds the
        label id to the graph.
        """
        label_map = {label: i for i, label in enumerate(label_list_)}

        tokens_a = tokenizer.tokenize(example.text_a)
        tokens_b = tokenizer.tokenize(example.text_b) if example.text_b else None

        if tokens_b:
            # Reserve room for [CLS], [SEP], [SEP].
            self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        elif len(tokens_a) > max_seq_length - 2:
            # Reserve room for [CLS] and [SEP].
            tokens_a = tokens_a[0:(max_seq_length - 2)]

        tokens = ["[CLS]"] + list(tokens_a) + ["[SEP]"]
        segment_ids = [0] * len(tokens)
        if tokens_b:
            tokens += list(tokens_b) + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = list(tokenizer.convert_tokens_to_ids(tokens))
        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        label_id = label_map.get(example.label, 0)

        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % example.guid)
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            tf.logging.info("label: %s (id = %d)" % (example.label, label_id))

        feature_ = InputFeatures(
            input_ids=input_ids, input_mask=input_mask,
            segment_ids=segment_ids, label_id=label_id,
            is_real_example=True)
        return feature_
# Demo: load the model once, time a single prediction and print the result.
predict_detect = PredictDetect()
try:
    data_row = '测试数据1....'
    start = time.perf_counter()
    index, score = predict_detect.detect(data_row)
    print(time.perf_counter() - start)
    # detect() returns the class *name* from label_list, not a numeric id.
    print(index)
    print(score)
except Exception:
    # Best-effort demo: print the full traceback instead of only str(e).
    # For a KeyError, str(e) is just the bare key, which hides where it failed.
    import traceback
    traceback.print_exc()
finally:
    # Release the TF session held by the predictor.
    predict_detect.sess.close()