Tensorflow+Flask在深度学习模型接口化遇到的问题

+ 报错:Tensor Tensor("dense_2/Softmax:0", shape=(?, 14), dtype=float32) is not an element of this graph.

- 解决:在加载模型之后进行一次predict,随便给出一个变量a调用model.predict()方法,注意a的shape与预测x_train一致。

# 加载模型
model_filePath = 'models/per-rel-06-0.7925.h5'
model = load_model(model_filePath, custom_objects={"Attention": Attention})
# 在加载模型之后进行一次predict,注意a的shape与预测x_train一致
a = np.ones((1, 128, 768))
model.predict(a)

+ 报错:RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
- 解决:声明graph即可。

graph = tf.get_default_graph()
with graph.as_default():
    # 这里把调用包含model.predict()方法的代码放到这里

+ 报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable dense_1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/dense_1/bias/class tensorflow::Var does not exist.[[{{node dense_1/BiasAdd/ReadVariableOp}}]]
- 解决:在调用model.predict()方法前进行全局sess和graph的声明。

global sess
global graph
with graph.as_default():
    # 这里把调用包含model.predict()方法的代码放到这里

+ 最终的代码如下:

from flask import Flask, request
from flask_cors import CORS
from keras.models import load_model
from tensorflow.python.keras.backend import set_session
from att import Attention
from bert.extract_feature import BertVector
import json
import numpy as np
import tensorflow as tf


app = Flask(__name__)
# 方便跨域使用
CORS(app, support_credential=True)
# 程序开始时声明
sess = tf.Session()
graph = tf.get_default_graph()
# 在model加载前添加set_session
set_session(sess)
# 加载模型
model_filePath = 'models/per-rel-06-0.7925.h5'
model = load_model(model_filePath, custom_objects={"Attention": Attention})
# 在加载模型之后进行一次predict,注意a的shape与预测x_train一致
a = np.ones((1, 128, 768))
model.predict(a)


@app.route('/relation')
def predict_relation():
    # get请求,得到需要解析的原文和人名
    text1 = request.args.get('text')
    # 通过#进行人名的分割,在text处理时替换$,并用#来替换人名
    per1, per2, doc = text1.split(',')
    text = '$'.join([per1, per2, doc.replace(per1, len(per1) * '#').replace(per2, len(per2) * '#')])
    bert_model = BertVector(pooling_strategy="NONE", max_seq_len=128)
    vec = bert_model.encode([text])["encodes"][0]
    x_train = np.array([vec])
    # 在调用model.predict()方法前进行全局sess和graph的声明
    global sess
    global graph
    with graph.as_default():
        set_session(sess)
        predicted = model.predict(x_train)
        y = np.argmax(predicted[0])
        with open('data/rel_dict.json', 'r', encoding='utf-8') as f:
            rel_dict = json.load(f)
        id_rel_dict = {v: k for k, v in rel_dict.items()}
        res = id_rel_dict[y]
    return '预测人物关系: %s' % res


if __name__ == '__main__':
    app.run('0.0.0.0', port=12345, debug=True)

运行成功~

 

  • 2
    点赞
  • 0
    收藏
    觉得还不错? 一键收藏
  • 2
    评论
评论 2
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值