BERT TextClassify Estimator预测速度慢的问题

1.在使用BERT做文本分类时,用tf.estimator.Estimator做预测的速度总是很慢,仅仅一个生成器的解析就很耗费时间;而且无论使用的是ckpt还是pb,预测时总是会变成对ckpt模型的预测。因此,本文重写了基于pb模型的预测代码:

BERT模型训练完以后得到的文件有:

pb文件夹中的classification_model.pb就是预测使用的pb格式模型;

2.首先读取pb模型:

# Read the serialized .pb file and parse it into a GraphDef.
with tf.gfile.GFile(self.model_path, "rb") as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import the GraphDef into a fresh graph; name='' means no prefix is added,
# so nodes keep their saved names (e.g. 'input_ids').
with tf.Graph().as_default() as graph:
    tf.import_graph_def(graph_def, name='')

# graph: this variable can be returned to the caller

在查看graph以后知道模型的输入输出的名字:

input_ids
input_mask
Shape
strided_slice/stack
...
...
...
loss/BiasAdd
loss/Softmax
pred_prob

# 中间节点省略

可以看出输入有两个:input_ids、input_mask;

输出为:pred_prob

 

3.使用graph.get_tensor_by_name()方法来获得已经保存的操作(operations)和placeholder variables以及创建一个session;

# Fetch the two inputs and the output by name; ':0' selects the first
# output tensor of the named operation.
input_ids = graph.get_tensor_by_name('input_ids:0')
input_mask = graph.get_tensor_by_name('input_mask:0')
pred_prob = graph.get_tensor_by_name('pred_prob:0')

# Session bound to the imported graph; it is reused for every prediction.
sess = tf.Session(graph=graph)

4.数据的预处理与训练时大同小异,都是把一条text转换成features;

features包含:

input_ids
input_mask
segment_ids
label_id
is_real_example

5.后面就是使用前面得到的session和数据做预测,代码很简单:

# Each feature list is wrapped in an outer list, so a batch of one example is fed.
preds_evaluated = sess.run([pred_prob], feed_dict={input_ids: [features.input_ids], input_mask: [features.input_mask]})

# The fetch list has a single entry, so element 0 is the value of pred_prob.
pred = preds_evaluated[0]
# Position of the largest value along axis 1, taken for the first (only) row.
pred_index = pred.argmax(axis=1)[0]
# The value at that position is reported as the score.
pred_score = pred[0][pred_index]

得到的pred_index就是类别的id,pred_score就是分数;

 

最主要的是弄清pb模型的节点名称,知道了输入、输出节点,就可以完成接下来的一系列步骤。

 

6.完整测试代码

import time

import cv2
import tensorflow as tf

import tokenization
import csv


class InputExample(object):
    """One raw example: an id, one or two text segments and a label."""

    def __init__(self, guid, text_a, text_b=None, label=None):
        # guid identifies the example; text_b and label are optional.
        self.guid, self.text_a = guid, text_a
        self.text_b, self.label = text_b, label


class InputFeatures(object):
    """Container for the per-example values produced from an InputExample."""

    def __init__(self, input_ids, input_mask, segment_ids, label_id,
                 is_real_example=True):
        # The three per-token sequences.
        (self.input_ids,
         self.input_mask,
         self.segment_ids) = input_ids, input_mask, segment_ids
        # Label index and the real/padding-example flag.
        self.label_id = label_id
        self.is_real_example = is_real_example


class PredictDetect(object):
    """Classify single sentences with a frozen BERT classification graph (.pb).

    The graph is loaded and a session is opened once in the constructor;
    `detect` then only tokenizes the sentence and runs the session.
    """

    def __init__(self, model_path='./pb/classification_model.pb',
                 vocab_file='./pb/vocab.txt', label_list=None,
                 max_seq_length=150):
        """Load the frozen graph, open a session and build the tokenizer.

        Args:
            model_path: path of the frozen .pb model file.
            vocab_file: path of the vocabulary file given to the tokenizer.
            label_list: class names, indexed by the predicted class id.
                Defaults to ['class1', 'class2', 'class3'].
            max_seq_length: length every input is truncated/padded to.
                NOTE(review): presumably has to match the sequence length
                the graph was exported with - confirm against the model.
        """
        if label_list is None:
            label_list = ['class1', 'class2', 'class3']
        self.label_list = list(label_list)
        self.max_seq_length = max_seq_length

        self.model_path = model_path
        self.detection_graph = self._load_model()

        self.sess = tf.Session(graph=self.detection_graph)

        # ':0' selects the first output tensor of each named operation.
        self.input_ids = self.detection_graph.get_tensor_by_name('input_ids:0')
        self.input_mask = self.detection_graph.get_tensor_by_name('input_mask:0')
        self.pred_prob = self.detection_graph.get_tensor_by_name('pred_prob:0')

        self.input_map = {"input_ids": self.input_ids, "input_mask": self.input_mask}

        self.tokenizer = tokenization.FullTokenizer(vocab_file=vocab_file, do_lower_case=True)

    def close(self):
        """Close the TensorFlow session held by this predictor."""
        self.sess.close()

    def _load_model(self):
        """Parse `self.model_path` and return a new graph containing it.

        The name of every operation is printed so the input/output node
        names can be read off the output.
        """
        with tf.gfile.GFile(self.model_path, "rb") as f:
            graph_def = tf.GraphDef()
            graph_def.ParseFromString(f.read())

        # name='' imports the nodes without a prefix, so they keep the
        # names used by get_tensor_by_name in __init__.
        with tf.Graph().as_default() as graph:
            tf.import_graph_def(graph_def, name='')

        for op in graph.get_operations():
            print(op.name)

        return graph

    def _data_init(self, sentence):
        """Turn a raw sentence into an `InputFeatures` instance."""
        example = self.one_example(sentence)
        return self.convert_single_example(
            0, example, self.label_list, self.max_seq_length, self.tokenizer)

    def detect(self, sentence):
        """Predict the class of `sentence`.

        Returns:
            A `(label, score)` tuple: the entry of `label_list` at the
            arg-max position of the model output, and the output value at
            that position.
        """
        features = self._data_init(sentence)

        # Each feature list is wrapped in a list: a batch of one example.
        feed_dict = {self.input_ids: [features.input_ids],
                     self.input_mask: [features.input_mask]}
        pred = self.sess.run([self.pred_prob], feed_dict=feed_dict)[0]

        pred_index = int(pred.argmax(axis=1)[0])
        pred_score = pred[0][pred_index]
        print(pred_index)

        return self.label_list[pred_index], pred_score

    @staticmethod
    def one_example(sentence):
        """Wrap `sentence` in an `InputExample` with the placeholder label '#'."""
        return InputExample(guid='pred-0', text_a=sentence, text_b=None, label='#')

    @staticmethod
    def _truncate_seq_pair(tokens_a, tokens_b, max_length):
        """Pop tokens in place until both lists together fit `max_length`.

        One token at a time is removed from the end of the longer list
        (from `tokens_b` when the lengths are equal).
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def convert_single_example(self, ex_index, example, label_list_, max_seq_length, tokenizer):
        """Converts a single `InputExample` into a single `InputFeatures`.

        Args:
            ex_index: running index of the example; the first five are logged.
            example: the `InputExample` to convert.
            label_list_: class names; the position in the list is the label id.
            max_seq_length: length the id/mask/segment lists are padded to.
            tokenizer: object providing `tokenize` and `convert_tokens_to_ids`.

        Returns:
            An `InputFeatures` whose three lists all have `max_seq_length`
            entries. `label_id` is None when `example.label` is not in
            `label_list_`, as is the case for the placeholder label that
            `one_example` assigns at prediction time.
        """
        label_map = {label: i for i, label in enumerate(label_list_)}

        tokens_a = list(tokenizer.tokenize(example.text_a))
        tokens_b = None
        if example.text_b:
            tokens_b = list(tokenizer.tokenize(example.text_b))

        if tokens_b:
            # Reserve room for [CLS], [SEP], [SEP].
            self._truncate_seq_pair(tokens_a, tokens_b, max_seq_length - 3)
        elif len(tokens_a) > max_seq_length - 2:
            # Reserve room for [CLS] and [SEP].
            tokens_a = tokens_a[:max_seq_length - 2]

        # First segment (ids 0): [CLS] tokens_a [SEP]
        tokens = ["[CLS]"] + tokens_a + ["[SEP]"]
        segment_ids = [0] * len(tokens)

        # Optional second segment (ids 1): tokens_b [SEP]
        if tokens_b:
            tokens += tokens_b + ["[SEP]"]
            segment_ids += [1] * (len(tokens_b) + 1)

        input_ids = list(tokenizer.convert_tokens_to_ids(tokens))

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding = [0] * (max_seq_length - len(input_ids))
        input_ids += padding
        input_mask += padding
        segment_ids += padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        # A prediction-time example carries a placeholder label that is not
        # in label_list_; a plain label_map[...] lookup would raise KeyError
        # and make every detect() call fail.
        label_id = label_map.get(example.label)
        if ex_index < 5:
            tf.logging.info("*** Example ***")
            tf.logging.info("guid: %s" % example.guid)
            tf.logging.info("tokens: %s" % " ".join(
                [tokenization.printable_text(x) for x in tokens]))
            tf.logging.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
            tf.logging.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
            tf.logging.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
            # %s rather than %d: label_id may be None.
            tf.logging.info("label: %s (id = %s)" % (example.label, label_id))

        return InputFeatures(
            input_ids=input_ids, input_mask=input_mask,
            segment_ids=segment_ids, label_id=label_id,
            is_real_example=True)


# Build the predictor once: the constructor loads the graph and opens the
# session, so only detect() is timed below.
predict_detect = PredictDetect()

try:
    data_row = '测试数据1....'

    start = time.time()
    # detect() returns the class label (not a numeric index) and its score.
    index, score = predict_detect.detect(data_row)
    print(time.time() - start)

    print(index)
    print(score)
except Exception as e:
    # Best-effort demo: keep running, but include the exception type.
    # str(e) alone can be uninformative (a KeyError prints only its key).
    print("%s: %s" % (type(e).__name__, e))

 

  • 3
    点赞
  • 4
    收藏
    觉得还不错? 一键收藏
  • 3
    评论
评论 3
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值