Bert训练MRPC数据集,将预测模型写成API调用,以及简单html界面使用,后台服务使用flask


本文代码已上传到Github: https://github.com/luoyangbiao/bert_flask

1. 用Bert训练MRPC数据集,保存模型

  1. 这里运行官方的run_classifier.py代码就好了,我在Github上的代码也写了个脚本,运行脚本就可以进行训练。
  2. 这里我把GLUE数据集里的MRPC部分也上传到了Github上,模型数据太大需要在google提供的链接下载。
git clone https://github.com/luoyangbiao/bert_flask.git
cd bert_flask
wget https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip
unzip uncased_L-12_H-768_A-12.zip
bash classifier.sh

2.(API 以及简易 html 调用),模型的使用

  1. python bert_api_3.py 是flask的api调用方式,用post方法传入两个句子的 json ,返回模型预测结果。端口号可以改,我这里使用5005。
    在这里插入图片描述
  2. python bert_predi_web.py是使用html界面显示调用,端口号我这里还是使用5005,路径改成了根目录。在浏览器中输入 [公网ip地址:5005] 并回车即可使用模型。写了GET和POST方法。

3. 代码讲解

代码部分分为训练部分和flask调用做预测两个部分,训练部分就是官方的示例代码,训练MRPC数据集。这里主要说一下flask调用做预测部分的代码:

1)API调用方法:
将POST和GET两种方法都放置在 /pred 目录下,API调用时使用json格式传入"sentence1"和"sentence2"两个量。返回结果的json,正确返回时包括不相似0和相似1的结果,发生错误时返回error message。使用5005端口 (需要端口可用)

app = Flask(__name__)
if FLAGS.do_predict:
predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
predict_drop_remainder = True if FLAGS.use_tpu else False

@app.route('/pred',methods=['POST','GET'])
def index():
	response = {}
	try:
		data = request.json
		predict_examples = [InputExample('predict',data["sentence1"],data["sentence2"], '0')]
		num_actual_predict_examples = len(predict_examples)
		if FLAGS.use_tpu:
			while len(predict_examples) % FLAGS.predict_batch_size != 0:
				predict_examples.append(PaddingInputExample())
		file_based_convert_examples_to_features(predict_examples,label_list,
										    FLAGS.max_seq_length, tokenizer,predict_file)
		predict_input_fn = file_based_input_fn_builder(input_file=predict_file,
											  seq_length=FLAGS.max_seq_length,
											  is_training=False,
											  drop_remainder=predict_drop_remainder)
		result = estimator.predict(input_fn=predict_input_fn)
		api_return = []
		#list用于以后添加返回多个句子同时计算的功能
		for (i, prediction) in enumerate(result):
			probabilities = prediction["probabilities"]
			api_return.append(probabilities)
		response["prediction_0"] = str(list(api_return[0])[0])
		response["prediction_1"] = str(list(api_return[0])[1])
	except Exception as e:
		response["error_message"] = "An error occur, please read the document or contact me by luoyangbiao@bupt.edu.cn"
	return json.dumps(response)
app.run("0.0.0.0", port=5005, threaded=True)	

2)html界面调用:
写一个html文件用来传参及显示,放在文件夹 /templates 里,将 html 界面用到的图片素材放在文件夹 /static 里。传入两个text类型的字符串变量,预测,将结果返回到html界面的变量中进行显示。 threaded=True 让 flask 可以同时处理多个请求,该功能暂未进行扩展,敬请期待后续博文。路径改到了根目录下 / ,即输入ip地址:端口号即可发送GET请求。

app = Flask(__name__)
if FLAGS.do_predict:
	predict_file = os.path.join(FLAGS.output_dir, "predict.tf_record")
	predict_drop_remainder = True if FLAGS.use_tpu else False

	@app.route('/',methods=['POST','GET'])
	def index():
		response = {}
		try:
			if request.method == "GET":
				return render_template('index.html')
			if request.method == "POST":
				sentence1 = request.form['sentence1']
				sentence2 = request.form['sentence2']
			predict_examples = [InputExample('predict',sentence1,sentence2, '0')]
			if FLAGS.use_tpu:
				while len(predict_examples) % FLAGS.predict_batch_size != 0:
					predict_examples.append(PaddingInputExample())
			file_based_convert_examples_to_features(predict_examples, label_list,FLAGS.max_seq_length, tokenizer,predict_file)
			predict_input_fn = file_based_input_fn_builder(input_file=predict_file,seq_length=FLAGS.max_seq_length,is_training=False,drop_remainder=predict_drop_remainder)
			result = estimator.predict(input_fn=predict_input_fn)
			api_return = []
			for (i, prediction) in enumerate(result):
				probabilities = prediction["probabilities"]
				api_return.append(probabilities)
			pred_0 = round((list(api_return[0])[0])*100,2)
			pred_1 = round((list(api_return[0])[1])*100,2)
			response["prediction_0"] = str(pred_0) + '%'
			response["prediction_1"] = str(pred_1) + '%'
		except Exception as e:
			response["error_message"] = "An error occur, please read the document or contact me by luoyangbiao@bupt.edu.cn"
			return render_template('index.html', RESULT = response["prediction_1"], SENTENCE1 = sentence1, SENTENCE2 = sentence2)
		app.run("0.0.0.0", port=5005, threaded=True)

  • 3
    点赞
  • 13
    收藏
    觉得还不错? 一键收藏
  • 1
    评论
评论 1
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值