Estimater.predict是tensorflow的高阶API,但是在使用中常常会遇到如下情况: 单次预测一个大文件的速度正常,但是想做成接口来实时预测速度却缓慢:因为每次预测都会重新reload一遍计算图。
那么这个问题是否有解呢?答案:yes。可以在Estimater的层面,实现tensorflow Estimater.predict 的实时预测,将计算图只读取一遍后常驻内存(这里吐槽一下,个人感觉这个功能很重要,毕竟工程化接口必须得做到实时预测,但是google好像对这个方法都没有重点强调…)。
最近在改写BERT模型的时候,从bert-as-service中学习了一波(强烈安利一下https://github.com/hanxiao/bert-as-service,感谢大佬无私的开源)。 此外,该帖子https://groups.google.com/a/tensorflow.org/forum/#!topic/discuss/rOP4VKcfphg 对这个问题也进行了讨论:
最后得到了解答:
那么这里我就简单的解释一下如何解决:
使用: tf.data.Dataset.from_generator 来读取数据,代替 tf.data.Dataset.from_tensor_slices等其他方法。 然后维护一个生成器,来不停的yield一个成员变量,期间保持生成器开启。每当要实时预测时,修改该成员变量为待预测数据,即可直接输出预测结果。
此外,如果不用tf.data.Dataset.from_generator 来读取数据, 那么要么使用feed_dict的方式(不推荐,这样无法使用Estimater),要么自己构造生成器,见该代码:https://github.com/marcsto/rl/blob/master/src/fast_predict2.py
还是直接贴上代码解释一波吧:
用法:直接将你的Esimator用FastPredict类包装起来,或者也按照相同的思路改写
例如: classifier = FastPredict(learn.Estimator(model_fn=model_params.model_fn, model_dir=model_params.model_dir), my_input_fn)
预测时,调用classifier.predict(feature_batch)即可,在classifier生存期间,整个计算图将常驻内存,每次predict无需reload计算图
"""
Speeds up estimator.predict by preventing it from reloading the graph on each call to predict.
加速estimator.predict,防止每次predict时重新加载计算图
It does this by creating a python generator to keep the predict call open.
原理:创建一个python生成器,保持predict进程处于一直开启状态