Keras入门(七)使用Flask+Keras-bert构建模型预测服务

  在文章NLP(三十四)使用keras-bert实现序列标注任务中,我们介绍了如何使用keras-bert模块,利用BERT中文预训练模型来实现序列标注任务的模型训练、模型评估和模型预测。其中,模型预测是通过加载生成的h5文件来实现的。
  本文将会介绍如何使用Flask构建模型预测的HTTP服务。
  我们遵循正常的思路,即先使用Keras加载保存后的h5模型文件,利用Flask对新输入的文本进行模型预测,最后给出预测结果。我们对人民日报命名实体实体数据集进行模型训练,采用文章NLP(三十四)使用keras-bert实现序列标注任务中的模型,训练后得到example_ner.h5文件,模型预测的HTTP服务脚本如下:

# -*- coding: utf-8 -*-
import json
import traceback
import numpy as np
from keras.models import load_model
from keras_bert import get_custom_objects
from keras_contrib.layers import CRF
from keras_contrib.losses import crf_loss
from keras_contrib.metrics import crf_accuracy
from flask import Flask, request

from model_train import PreProcessInputData, id_label_dict


# 将BIO标签转化为方便阅读的json格式
def bio_to_json(string, tags):
    item = {"string": string, "entities": []}
    entity_name = ""
    entity_start = 0
    iCount = 0
    entity_tag = ""

    for c_idx in range(min(len(string), len(tags))):
        c, tag = string[c_idx], tags[c_idx]
        if c_idx < len(tags)-1:
            tag_next = tags[c_idx+1]
        else:
            tag_next = ''

        if tag[0] == 'B':
            entity_tag = tag[2:]
            entity_name = c
            entity_start = iCount
            if tag_next[2:] != entity_tag:
                item["entities"].append({"word": c, "start": iCount, "end": iCount + 1, "type": tag[2:]})
        elif tag[0] == "I":
            if tag[2:] != tags[c_idx-1][2:] or tags[c_idx-1][2:] == 'O':
                tags[c_idx] = 'O'
                pass
            else:
                entity_name = entity_name + c
                if tag_next[2:] != entity_tag:
                    item["entities"].append({"word": entity_name, "start": entity_start, "end": iCount + 1, "type": entity_tag})
                    entity_name = ''
        iCount += 1
    return item


app = Flask(__name__)


@app.route("/model/ner", methods=["GET", "POST"])
def get_geo():
    return_result = {"code": 200, "message": "success", "data": []}
    try:
        text = request.get_json()["text"].replace(" ", "")
        word_labels, seq_types = PreProcessInputData([text])

        # 模型预测
        predicted = ner_model.predict([word_labels, seq_types])
        y = np.argmax(predicted[0], axis=1)
        tag = [id_label_dict[_] for _ in y]

        # 输出预测结果
        result = bio_to_json(text, tag[1:-1])
        return_result["data"] = result

    except Exception:
        return_result["code"] = 400
        return_result["message"] = traceback.format_exc()

    return json.dumps(return_result, ensure_ascii=False, indent=2)


if __name__ == '__main__':
    # 加载训练好的模型
    custom_objects = get_custom_objects()
    for key, value in {'CRF': CRF, 'crf_loss': crf_loss, 'crf_accuracy': crf_accuracy}.items():
        custom_objects[key] = value
    ner_model = load_model("example_ner.h5", custom_objects=custom_objects)
    # 启动HTTP服务
    app.run(host="0.0.0.0", port=25000)

看上去上面的服务并没有什么问题,但当我们进行HTTP请求时,报错如下:

File "/home/jclian91/.conda/envs/py3-lmj/lib/python3.6/site-packages/tensorflow/python/framework/ops.py", line 3875, in _as_graph_element_locked
    raise ValueError("Tensor %s is not an element of this graph." % obj)
ValueError: Tensor Tensor("crf_1/cond/Merge:0", shape=(?, ?, 7), dtype=float32) is not an element of this graph.

上网搜资料,发现这种错误非常常见,其中一种解决方法如下:

导入模块:

import tensorflow as tf
from keras.backend import set_session

在加载模型(load_model)的代码前,加几行代码如下:

	sess = tf.Session()
    graph = tf.get_default_graph()
    set_session(sess)

同时在HTTP服务的模型预测(ner_model.predict)前,加几行代码如下:

	# 模型预测
	global sess
	global graph
	with graph.as_default():
	    set_session(sess)
	    predicted = ner_model.predict([word_labels, seq_types])

这样再次启动模型预测HTTP脚本,可以发现模型预测的HTTP请求是正常的。

$ curl --location --request POST 'http://192.168.1.193:25000/model/ner' \
> --header 'Content-Type: application/json' \
> --data-raw '{
>     "text": "美国卫生部长阿扎尔辞职 原因曝光"
> }'
{
  "code": 200,
  "message": "success",
  "data": {
    "string": "美国卫生部长阿扎尔辞职原因曝光",
    "entities": [
      {
        "word": "美国卫生部",
        "start": 0,
        "end": 5,
        "type": "ORG"
      },
      {
        "word": "阿扎尔",
        "start": 6,
        "end": 9,
        "type": "PER"
      }
    ]
  }
}

  该脚本已上传至Github,网址为:https://github.com/percent4/keras_bert_sequence_labeling/blob/master/model_server.py
  感谢阅读~
  2021年1月16日于上海浦东

  • 1
    点赞
  • 6
    收藏
    觉得还不错? 一键收藏
  • 2
    评论

“相关推荐”对你有帮助么?

  • 非常没帮助
  • 没帮助
  • 一般
  • 有帮助
  • 非常有帮助
提交
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值