Tensorflow 利用高阶API Estimater.predict 实现实时预测,避免reload计算图

本文探讨了在Tensorflow中使用Estimater.predict进行实时预测时遇到的计算图重载问题,并提供了解决方案。通过使用tf.data.Dataset.from_generator读取数据,保持生成器开启,可以实现在不重新加载计算图的情况下进行高效预测。文章还引用了相关资源,包括BERT-as-service项目和Google论坛讨论,并提供了FastPredict类的示例代码,以帮助读者理解并应用这一优化策略。
摘要由CSDN通过智能技术生成

Estimater.predict是tensorflow的高阶API,但是在使用中常常会遇到如下情况: 单次预测一个大文件的速度正常,但是想做成接口来实时预测速度却缓慢:因为每次预测都会重新reload一遍计算图。

那么这个问题是否有解呢?答案:yes。可以在Estimater的层面,实现tensorflow Estimater.predict 的实时预测,将计算图只读取一遍后常驻内存(这里吐槽一下,个人感觉这个功能很重要,毕竟工程化接口必须得做到实时预测,但是google好像对这个方法都没有重点强调…)。

最近在改写BERT模型的时候,从bert-as-service中学习了一波(强烈安利一下https://github.com/hanxiao/bert-as-service,感谢大佬无私的开源)。 此外,该帖子https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/rOP4VKcfphg 对这个问题也进行了讨论:

在这里插入图片描述
最后得到了解答:
在这里插入图片描述

那么这里我就简单的解释一下如何解决:
使用: tf.data.Dataset.from_generator 来读取数据,代替 tf.data.Dataset.from_tensor_slices等其他方法。 然后维护一个生成器,来不停的yield一个成员变量,期间保持生成器开启。每当要实时预测时,修改该成员变量为待预测数据,即可直接输出预测结果。

此外,如果不用tf.data.Dataset.from_generator 来读取数据, 那么要么使用feed_dict的方式(不推荐,这样无法使用Estimater),要么自己构造生成器,见该代码:https://github.com/marcsto/rl/blob/master/src/fast_predict2.py


还是直接贴上代码解释一波吧:
用法:直接将你的Esimator用FastPredict类包装起来,或者也按照相同的思路改写
例如: classifier = FastPredict(learn.Estimator(model_fn=model_params.model_fn, model_dir=model_params.model_dir), my_input_fn)
预测时,调用classifier.predict(feature_batch)即可,在classifier生存期间,整个计算图将常驻内存,每次predict无需reload计算图

"""
    Speeds up estimator.predict by preventing it from reloading the graph on each call to predict.
    加速estimator.predict,防止每次predict时重新加载计算图
    It does this by creating a python generator to keep the predict call open.
    原理:创建一个python生成器,保持predict进程处于一直开启状态
 
  • 12
    点赞
  • 26
    收藏
    觉得还不错? 一键收藏
  • 29
    评论
评论 29
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值