tensorflow(3)同时加载多个模型的一次尝试

  本文将会介绍如何使用tensorflow来同时加载多个模型,其中生成的模型文件为ckpt格式。
  本文以Github上的bertNER项目为模板来介绍。项目结构如下:
项目结构代码
我们已经生成了三个模型:war,weapon,geo,每个模型都有各自对应的ckpt、config、maps、以及log文件。三个模型的模型参数都一样,最大长度为128。

加载一个模型的服务

  由于bertNER项目已经为我们提供好了预测脚本predict.py,因此我们很方便地就能将它改成为HTTP服务的方式,代码如下:

# -*- coding: utf-8 -*-
import json
import pickle

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

import tensorflow as tf
from utils import create_model, get_logger, load_config
from model import Model
from loader import input_from_line

define("port", default=5005, help="run on the given port", type=int)

with open("geo_maps.pkl", "rb") as f:
    geo_tag_to_id, geo_id_to_tag = pickle.load(f)

# 加载模型
sess = tf.Session()
geo_model = create_model(sess, Model, "geo_ckpt", load_config("geo_config_file"), get_logger("geo_train.log"))


# Geo模型预测
class GeoHandler(tornado.web.RequestHandler):
    # post函数
    def post(self):
        event = self.get_argument('event')
        result = geo_model.evaluate_line(sess, input_from_line(event, 128, geo_tag_to_id), geo_id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))


# 主函数
def main():
    # 开启tornado服务
    tornado.options.parse_command_line()
    # 定义app
    app = tornado.web.Application(
            handlers=[(r'/model/entity/geo', GeoHandler)],
           )
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()


main()

启动该脚本后,最后一行为:

2020-10-28 20:35:39,578 - geo_train.log - INFO - Reading model parameters from geo_ckpt\ner.ckpt-8836

即表明geo模型已经加载成功,该HTTP服务也可以成功调用。

同时加载多个模型的服务

  下面将介绍如何一次性加载三个模型,并能通过HTTP服务调用成功。
  在Tensorflow中,所有操作对象都包装到相应的Session中的,所以想要使用不同的模型就需要将这些模型加载到不同的Session中并在使用的时候申明是哪个Session,从而避免由于Session和想使用的模型不匹配导致的错误。而使用多个graph,就需要为每个graph使用不同的Session,但是每个graph也可以在多个Session中使用,这个时候就需要在每个Session使用的时候明确申明使用的graph。
  在不少文章中都给出了使用样例代码,笔者将给出一次性加载上述三个模型(geo、war、weapon)的代码,如下:

# -*- coding: utf-8 -*-
import json
import pickle

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

import tensorflow as tf
from utils import create_model, get_logger, load_config
from model import Model
from loader import input_from_line

define("port", default=5005, help="run on the given port", type=int)

with open("geo_maps.pkl", "rb") as f:
    geo_tag_to_id, geo_id_to_tag = pickle.load(f)
with open("geo_maps.pkl", "rb") as f:
    war_tag_to_id, war_id_to_tag = pickle.load(f)
with open("weapon_maps.pkl", "rb") as f:
    weapon_tag_to_id, weapon_id_to_tag = pickle.load(f)

g1 = tf.Graph()
g2 = tf.Graph()
g3 = tf.Graph()

sess1 = tf.Session(graph=g1)
sess2 = tf.Session(graph=g2)
sess3 = tf.Session(graph=g3)

# 加载第一个模型
with sess1.as_default():
    with g1.as_default():
        geo_model = create_model(sess1, Model, "geo_ckpt", load_config("geo_config_file"), get_logger("geo_train.log"))

with sess2.as_default():
    with g2.as_default():
        war_model = create_model(sess2, Model, "war_ckpt", load_config("war_config_file"), get_logger("war_train.log"))

with sess3.as_default():
    with g3.as_default():
        weapon_model = create_model(sess3, Model, "weapon_ckpt", load_config("weapon_config_file"), get_logger("weapon_train.log"))


# Geo模型预测
class GeoHandler(tornado.web.RequestHandler):
    # post函数
    def post(self):
        event = self.get_argument('event')
        with sess1.as_default():
            with sess1.graph.as_default():
                result = geo_model.evaluate_line(sess1, input_from_line(event, 128, geo_tag_to_id), geo_id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))


# War模型预测
class WarHandler(tornado.web.RequestHandler):
    # post函数
    def post(self):
        event = self.get_argument('event')
        with sess2.as_default():
            with sess2.graph.as_default():
                result = war_model.evaluate_line(sess2, input_from_line(event, 128, war_tag_to_id), war_id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))


# Weapon模型预测
class WeaponHandler(tornado.web.RequestHandler):
    # post函数
    def post(self):
        event = self.get_argument('event')
        with sess3.as_default():
            with sess3.graph.as_default():
                result = weapon_model.evaluate_line(sess3, input_from_line(event, 128, weapon_tag_to_id), weapon_id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))


# 主函数
def main():
    # 开启tornado服务
    tornado.options.parse_command_line()
    # 定义app
    app = tornado.web.Application(
            handlers=[(r'/model/entity/geo', GeoHandler),
                      (r'/model/entity/war', WarHandler),
                      (r'/model/entity/weapon', WeaponHandler)],
           )
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()

main()

在启动的输出中,我们看到有输出:

2020-10-28 20:41:08,149 - geo_train.log - INFO - Reading model parameters from geo_ckpt\ner.ckpt-8836
2020-10-28 20:41:14,846 - war_train.log - INFO - Reading model parameters from war_ckpt\ner.ckpt-23014
2020-10-28 20:41:22,262 - weapon_train.log - INFO - Reading model parameters from weapon_ckpt\ner.ckpt-28756

这表明该HTTP服务已经成功加载三个模型,并能调用成功。

总结

  本来通过在Tensorflow中创建多个Session,每个Session运行一个graph,在运行时制定默认Session的办法来实现同时加载多个模型,并封装成HTTP服务。
  当然Tensorflow同时加载多个模型的办法不止这一种,后面的文章将会介绍如何使用tensorflow/serving来同时加载多个模型。
  感谢大家的阅读~

参考文献

  1. bertNER项目:https://github.com/yumath/bertNER
  2. Tensorflow同时加载使用多个模型:https://www.cnblogs.com/arkenstone/p/7016481.html
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值