TensorFlow关于怎样解决Estimater.predict总是重新加载模型的问题

思路来自: https://blog.csdn.net/qq547276542/article/details/85080139#commentsedit
在此表示感谢

问题:

大家用Estimater.predict总是把模型重新load一遍,这样工程业务根本没法用。

解决方案:

  1. 使用python的生成器,让程序“误以为”有很多序列需要预测,这里构造yield形式即可;
  2. 利用tf.data.Dataset.from_generator,加载生成器,声明好数据结构和类型;
  3. 利用class类的实例变量self的全局性,通过self.inputs把数据“喂给到”生成器内部,这样就保证了数据的“源源不断”;
  4. 程序需要close的机制,用于保证停掉生成器的工作。

代码

我的代码是用于Bert模型的,思路说清楚了,具体功能请自行修改。

from tokenization import FullTokenizer, validate_case_matches_checkpoint
from conv_example import convert_single_example
from process import InputExample
from modeling import BertConfig
from model_func import model_fn_builder
from config import FLAGS, TFConfig
import tensorflow as tf
from tensorflow.python.estimator.estimator import Estimator
from tensorflow.python.estimator.run_config import RunConfig


class Fast(object):
    def __init__(self, label):
        self.label = label
        self.closed = False
        self.first_run = True
        self.tokenizer = FullTokenizer(
            vocab_file=FLAGS.vocab_file,
            do_lower_case=True)
        self.init_checkpoint = FLAGS.init_checkpoint
        self.seq_length = FLAGS.max_seq_length
        self.text = None
        self.num_examples = None
        self.predictions = None
        self.estimator = self.get_estimator()

    def get_estimator(self):
        validate_case_matches_checkpoint(True, self.init_checkpoint)
        bert_config = BertConfig.from_json_file(FLAGS.bert_config_file)  # 载入bert自定义配置
        if FLAGS.max_seq_length > bert_config.max_position_embeddings:  # 验证配置信息准确性
            raise ValueError(
                "Cannot use sequence length %d because the BERT pre_model "
                "was only trained up to sequence length %d" %
                (self.seq_length, bert_config.max_position_embeddings))

        run_config = RunConfig(
            model_dir=FLAGS.output_dir,
            save_checkpoints_steps=FLAGS.save_checkpoints_steps,
            session_config=TFConfig.cpu()
        )
        model_fn = model_fn_builder(  # 估计器函数,提供Estimator使用的model_fn,内部使用EstimatorSpec构建的
            bert_config=bert_config,
            num_labels=len(self.label),
            init_checkpoint=self.init_checkpoint,
            learning_rate=FLAGS.learning_rate,
            num_train_steps=None,
            num_warmup_steps=None,
            use_tpu=FLAGS.use_tpu,
            use_one_hot_embeddings=FLAGS.use_tpu)

        estimator = Estimator(  # 实例化估计器
            model_fn=model_fn,
            config=run_config,
            warm_start_from=self.init_checkpoint  # 新增预热
        )
        return estimator

    def get_feature(self, index, text):
        example = InputExample(f"text_{index}", text, None, self.label[0])
        feature = convert_single_example(index, example, self.label, self.seq_length, self.tokenizer)
        return feature.input_ids, feature.input_mask, feature.segment_ids, feature.label_id

    def create_generator(self):
        """构建生成器"""
        while not self.closed:
            self.num_examples = len(self.text)
            features = (self.get_feature(*f) for f in enumerate(self.text))
            yield dict(zip(("input_ids", "input_mask", "segment_ids", "label_ids"), zip(*features)))

    def input_fn_builder(self):
        """用于预测单独对预测数据进行创建,不基于文件数据"""
        dataset = tf.data.Dataset.from_generator(
            self.create_generator,
            output_types={'input_ids': tf.int32,
                          'input_mask': tf.int32,
                          'segment_ids': tf.int32,
                          'label_ids': tf.int32},
            output_shapes={
                'label_ids': (None),
                'input_ids': (None, None),
                'input_mask': (None, None),
                'segment_ids': (None, None)}
        )
        return dataset

    def predict(self, text):
        self.text = text
        if self.first_run:
            self.predictions = self.estimator.predict(
                input_fn=self.input_fn_builder, yield_single_examples=False)
            self.first_run = False
        probabilities = next(self.predictions)
        return [self.label[i] for i in probabilities["probabilities"].argmax(axis=1)]

    def close(self):
        self.closed = True
评论 60 您还未登录,请先 登录 后发表或查看评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
©️2022 CSDN 皮肤主题:数字50 设计师:CSDN官方博客 返回首页

打赏作者

小小逐月者

你的鼓励将是我创作的最大动力

¥2 ¥4 ¥6 ¥10 ¥20
输入1-500的整数
余额支付 (余额:-- )
扫码支付
扫码支付:¥2
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、C币套餐、付费专栏及课程。

余额充值