本文将会介绍如何使用tensorflow来同时加载多个模型,其中生成的模型文件为ckpt格式。
本文以Github上的bertNER项目为模板来介绍。项目结构如下:
我们已经生成了三个模型:war,weapon,geo,每个模型都有各自对应的ckpt、config、maps、以及log文件。三个模型的模型参数都一样,最大长度为128。
加载一个模型的服务
由于bertNER项目已经为我们提供好了预测脚本predict.py,因此我们很方便地就能将它改成为HTTP服务的方式,代码如下:
# -*- coding: utf-8 -*-
import json
import pickle
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import tensorflow as tf
from utils import create_model, get_logger, load_config
from model import Model
from loader import input_from_line
define("port", default=5005, help="run on the given port", type=int)
with open("geo_maps.pkl", "rb") as f:
geo_tag_to_id, geo_id_to_tag = pickle.load(f)
# 加载模型
sess = tf.Session()
geo_model = create_model(sess, Model, "geo_ckpt", load_config("geo_config_file"), get_logger("geo_train.log"))
# Geo模型预测
class GeoHandler(tornado.web.RequestHandler):
# post函数
def post(self):
event = self.get_argument('event')
result = geo_model.evaluate_line(sess, input_from_line(event, 128, geo_tag_to_id), geo_id_to_tag)
self.write(json.dumps(result, ensure_ascii=False))
# 主函数
def main():
# 开启tornado服务
tornado.options.parse_command_line()
# 定义app
app = tornado.web.Application(
handlers=[(r'/model/entity/geo', GeoHandler)],
)
http_server = tornado.httpserver.HTTPServer(app)
http_server.listen(options.port)
tornado.ioloop.IOLoop.instance().start()
main()
启动该脚本后,最后一行为:
2020-10-28 20:35:39,578 - geo_train.log - INFO - Reading model parameters from geo_ckpt\ner.ckpt-8836
即表明geo模型已经加载成功,该HTTP服务也可以成功调用。
同时加载多个模型的服务
下面将介绍如何一次性加载三个模型,并能通过HTTP服务调用成功。
在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Session和想使用的模型不匹配导致的错误。而使用多个graph,就需要为每个graph使用不同的Session,但是每个graph也可以在多个Session中使用,这个时候就需要在每个Session使用的时候明确申明使用的graph。
在不少文章中都给出了使用样例代码,笔者将给出一次性加载上述三个模型(geo、war、weapon)的代码,如下:
# -*- coding: utf-8 -*-
import json
import pickle
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import tensorflow as tf
from utils import create_model, get_logger, load_config
from model import Model
from loader import input_from_line
define("port", default=5005, help="run on the given port", type=int)
with open("geo_maps.pkl", "rb") as f:
geo_tag_to_id, geo_id_to_tag = pickle.load(f)
with open("geo_maps.pkl", "rb") as f:
war_tag_to_id, war_id_to_tag = pickle.load(f)
with open("weapon_maps.pkl", "rb") as f:
weapon_tag_to_id, weapon_id_to_tag = pickle.load(f)
g1 = tf.Graph()
g2 = tf.Graph()
g3 = tf.Graph()
sess1 = tf.Session(graph=g1)
sess2 = tf.Session(graph=g2)
sess3 = tf.Session(graph=g3)
# 加载第一个模型
with sess1.as_default():
with g1.as_default():
geo_model = create_model(sess1, Model, "geo_ckpt", load_config("geo_config_file"), get_logger("geo_train.log"))
with sess2.as_default():
with g2.as_default():
war_model = create_model(sess2, Model, "war_ckpt", load_config("war_config_file"), get_logger("war_train.log"))
with sess3.as_default():
with g3.as_default():
weapon_model = create_model(sess3, Model, "weapon_ckpt", load_config("weapon_config_file"), get_logger("weapon_train.log"))
# Geo模型预测
class GeoHandler(tornado.web.RequestHandler):
# post函数
def post(self):
event = self.get_argument('event')
with sess1.as_default():
with sess1.graph.as_default():
result = geo_model.evaluate_line(sess1, input_from_line(event, 128, geo_tag_to_id), geo_id_to_tag)
self.write(json.dumps(result, ensure_ascii=False))
# War模型预测
class WarHandler(tornado.web.RequestHandler):
# post函数
def post(self):
event = self.get_argument('event')
with sess2.as_default():
with sess2.graph.as_default():
result = war_model.evaluate_line(sess2, input_from_line(event, 128, war_tag_to_id), war_id_to_tag)
self.write(json.dumps(result, ensure_ascii=False))
# Weapon模型预测
class WeaponHandler(tornado.web.RequestHandler):
# post函数
def post(self):
event = self.get_argument('event')
with sess3.as_default():
with sess3.graph.as_default():
result = weapon_model.evaluate_line(sess3, input_from_line(event, 128, weapon_tag_to_id), weapon_id_to_tag)
self.write(json.dumps(result, ensure_ascii=False))
# 主函数
def main():
# 开启tornado服务
tornado.options.parse_command_line()
# 定义app
app = tornado.web.Application(
handlers=[(r'/model/entity/geo', GeoHandler),
(r'/model/entity/war', WarHandler),
(r'/model/entity/weapon', WeaponHandler)],
)
http_server = tornado.httpserver.HTTPServer(app)
http_server.listen(options.port)
tornado.ioloop.IOLoop.instance().start()
main()
在启动的输出中,我们看到有输出:
2020-10-28 20:41:08,149 - geo_train.log - INFO - Reading model parameters from geo_ckpt\ner.ckpt-8836
2020-10-28 20:41:14,846 - war_train.log - INFO - Reading model parameters from war_ckpt\ner.ckpt-23014
2020-10-28 20:41:22,262 - weapon_train.log - INFO - Reading model parameters from weapon_ckpt\ner.ckpt-28756
这表明该HTTP服务已经成功加载三个模型,并能调用成功。
总结
本来通过在Tensorflow中创建多个Session,每个Session运行一个graph,在运行时制定默认Session的办法来实现同时加载多个模型,并封装成HTTP服务。
当然Tensorflow同时加载多个模型的办法不止这一种,后面的文章将会介绍如何使用tensorflow/serving来同时加载多个模型。
感谢大家的阅读~
参考文献
- bertNER项目:https://github.com/yumath/bertNER
- Tensorflow同时加载使用多个模型:https://www.cnblogs.com/arkenstone/p/7016481.html