Bert句子相似度分析,用flask服务提供调用
本文代码已上传到Github: https://github.com/luoyangbiao/bert_flask
1. 用Bert训练MRPC数据集,保存模型
- 这里运行官方的run_classifier.py代码就好了,我在Github上的代码也写了个脚本,运行脚本就可以进行训练。
- 这里我把GLUE数据集里的MRPC部分也上传到了Github上,模型数据太大需要在google提供的链接下载。
git clone https://github.com/luoyangbiao/bert_flask.git
cd bert_flask
wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip
bash classifier.sh
2.(API 以及简易 html 调用),模型的使用
- python bert_api_3.py 是flask的api调用方式,用post方法传入两个句子的 json ,返回模型预测结果。端口号可以改,我这里使用5005。
- python bert_predi_web.py是使用html界面显示调用,端口号我这里还是使用5005,路径改成了根目录。在浏览器中输入 [公网ip地址:5005] 并回车即可使用模型。写了GET和POST方法。
3. 代码讲解
代码部分分为训练部分和flask调用做预测两个部分,训练部分就是官方的示例代码,训练MRPC数据集。这里主要说一下flask调用做预测部分的代码:
1)API调用方法:
将POST和GET两种方法都放置在 /pred 目录下,API调用时使用json格式传入"sentence1"和"sentence2"两个量。返回结果的json,正确返回时包括不相似0和相似1的结果,发生错误时返回error message。使用5005端口 (需要端口可用)
app = Flask(__name__)
if FLAGS.do_predict:
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
predict_drop_remainder = True if FLAGS.use_tpu else False
@app.route('/pred',methods=['POST','GET'])
def index():
response = {}
try:
data = request.json
predict_examples = [InputExample('predict',data["sentence1"],data["sentence2"], '0')]
num_actual_predict_examples = len(predict_examples)
if FLAGS.use_tpu:
while len(predict_examples) % FLAGS.predict_batch_size != 0:
predict_examples.append(PaddingInputExample())
file_based_convert_examples_to_features(predict_examples,label_list,
FLAGS.max_seq_length, tokenizer,predict_file)
predict_input_fn = file_based_input_fn_builder(input_file=predict_file,
seq_length=FLAGS.max_seq_length,
is_training=False,
drop_remainder=predict_drop_remainder)
result = estimator.predict(input_fn=predict_input_fn)
api_return = []
#list用于以后添加返回多个句子同时计算的功能
for (i, prediction) in enumerate(result):
probabilities = prediction["probabilities"]
api_return.append(probabilities)
response["prediction_0"] = str(list(api_return[0])[0])
response["prediction_1"] = str(list(api_return[0])[1])
except Exception as e:
response["error_message"] = "An error occur, please read the document or contact me by luoyangbiao@bupt.edu.cn"
return json.dumps(response)
app.run("0.0.0.0", port=5005, threaded=True)
2)html界面调用:
写一个html文件用来传参及显示,放在文件夹 /templates 里,将 html 界面用到的图片素材放在文件夹 /static 里。传入两个text类型的字符串变量,预测,将结果返回到html界面的变量中进行显示。 threaded=True 让 flask 可以同时处理多个请求,该功能暂未进行扩展,敬请期待后续博文。路径改到了根目录下 / ,即输入ip地址:端口号即可发送GET请求。
app = Flask(__name__)
if FLAGS.do_predict:
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
predict_drop_remainder = True if FLAGS.use_tpu else False
@app.route('/',methods=['POST','GET'])
def index():
response = {}
try:
if request.method == "GET":
return render_template('index.html')
if request.method == "POST":
sentence1 = request.form['sentence1']
sentence2 = request.form['sentence2']
predict_examples = [InputExample('predict',sentence1,sentence2, '0')]
if FLAGS.use_tpu:
while len(predict_examples) % FLAGS.predict_batch_size != 0:
predict_examples.append(PaddingInputExample())
file_based_convert_examples_to_features(predict_examples, label_list,FLAGS.max_seq_length, tokenizer,predict_file)
predict_input_fn = file_based_input_fn_builder(input_file=predict_file,seq_length=FLAGS.max_seq_length,is_training=False,drop_remainder=predict_drop_remainder)
result = estimator.predict(input_fn=predict_input_fn)
api_return = []
for (i, prediction) in enumerate(result):
probabilities = prediction["probabilities"]
api_return.append(probabilities)
pred_0 = round((list(api_return[0])[0])*100,2)
pred_1 = round((list(api_return[0])[1])*100,2)
response["prediction_0"] = str(pred_0) + '%'
response["prediction_1"] = str(pred_1) + '%'
except Exception as e:
response["error_message"] = "An error occur, please read the document or contact me by luoyangbiao@bupt.edu.cn"
return render_template('index.html', RESULT = response["prediction_1"], SENTENCE1 = sentence1, SENTENCE2 = sentence2)
app.run("0.0.0.0", port=5005, threaded=True)