+ 报错:Tensor Tensor("dense_2/Softmax:0", shape=(?, 14), dtype=float32) is not an element of this graph.
- 解决:在加载模型之后进行一次predict,随便给出一个变量a调用model.predict()方法,注意a的shape与预测x_train一致。
# 加载模型
model_filePath = 'models/per-rel-06-0.7925.h5'
model = load_model(model_filePath, custom_objects={"Attention": Attention})
# 在加载模型之后进行一次predict,注意a的shape与预测x_train一致
a = np.ones((1, 128, 768))
model.predict(a)
+ 报错:RuntimeError: The Session graph is empty. Add operations to the graph before calling run().
- 解决:声明graph即可。
graph = tf.get_default_graph()
with graph.as_default():
# 这里把调用包含model.predict()方法的代码放到这里
+ 报错:tensorflow.python.framework.errors_impl.FailedPreconditionError: Error while reading resource variable dense_1/bias from Container: localhost. This could mean that the variable was uninitialized. Not found: Resource localhost/dense_1/bias/class tensorflow::Var does not exist.[[{{node dense_1/BiasAdd/ReadVariableOp}}]]
- 解决:在调用model.predict()方法前进行全局sess和graph的声明。
global sess
global graph
with graph.as_default():
# 这里把调用包含model.predict()方法的代码放到这里
+ 最终的代码如下:
from flask import Flask, request
from flask_cors import CORS
from keras.models import load_model
from tensorflow.python.keras.backend import set_session
from att import Attention
from bert.extract_feature import BertVector
import json
import numpy as np
import tensorflow as tf
app = Flask(__name__)
# 方便跨域使用
CORS(app, support_credential=True)
# 程序开始时声明
sess = tf.Session()
graph = tf.get_default_graph()
# 在model加载前添加set_session
set_session(sess)
# 加载模型
model_filePath = 'models/per-rel-06-0.7925.h5'
model = load_model(model_filePath, custom_objects={"Attention": Attention})
# 在加载模型之后进行一次predict,注意a的shape与预测x_train一致
a = np.ones((1, 128, 768))
model.predict(a)
@app.route('/relation')
def predict_relation():
# get请求,得到需要解析的原文和人名
text1 = request.args.get('text')
# 通过#进行人名的分割,在text处理时替换$,并用#来替换人名
per1, per2, doc = text1.split(',')
text = '$'.join([per1, per2, doc.replace(per1, len(per1) * '#').replace(per2, len(per2) * '#')])
bert_model = BertVector(pooling_strategy="NONE", max_seq_len=128)
vec = bert_model.encode([text])["encodes"][0]
x_train = np.array([vec])
# 在调用model.predict()方法前进行全局sess和graph的声明
global sess
global graph
with graph.as_default():
set_session(sess)
predicted = model.predict(x_train)
y = np.argmax(predicted[0])
with open('data/rel_dict.json', 'r', encoding='utf-8') as f:
rel_dict = json.load(f)
id_rel_dict = {v: k for k, v in rel_dict.items()}
res = id_rel_dict[y]
return '预测人物关系: %s' % res
if __name__ == '__main__':
app.run('0.0.0.0', port=12345, debug=True)
运行成功~